// File: Linker\TypeMapInfo.cs
// Project: src\src\tools\illink\src\linker\Mono.Linker.csproj (illink)
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
//
// TypeMapInfo.cs
//
// Author:
//   Jb Evain (jbevain@novell.com)
//
// (C) 2009 Novell, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Mono.Cecil;
 
namespace Mono.Linker
{
 
	/// <summary>
	/// Lazily-built map of virtual method relationships: which methods override which, which methods implement
	/// which interface methods, which default interface methods (DIMs) provide implementations, and the recursive
	/// set of interfaces each type implements.
	/// </summary>
	/// <remarks>
	/// An assembly is mapped the first time a query needs it (see <see cref="EnsureProcessed"/>), so answers about
	/// relationships that cross assembly boundaries may be incomplete until the other assemblies are processed too.
	/// </remarks>
	public class TypeMapInfo
	{
		// Assemblies that have already been mapped; makes EnsureProcessed idempotent.
		readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition> ();
		readonly LinkContext context;

		// Keyed by an overriding/implementing method; each entry describes a method it overrides or implements.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();

		// Keyed by an overridden base/interface method; each entry describes a method that overrides or implements it.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();

		// Keyed by an interface method; each entry describes a default interface method implementing it for some type.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> default_interface_implementations = new Dictionary<MethodDefinition, List<OverrideInformation>> ();

		// Keyed by a mapped type; lists every interface it implements (directly or through other interfaces)
		// together with the first chain of InterfaceImplementation records found that leads to it.
		readonly Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();

		public TypeMapInfo (LinkContext context)
		{
			this.context = context;
		}

		/// <summary>
		/// Maps every type of <paramref name="assembly"/> the first time it is called for that assembly;
		/// later calls for the same assembly do nothing.
		/// </summary>
		public void EnsureProcessed (AssemblyDefinition assembly)
		{
			// Record the assembly before mapping it so that any nested request for it returns immediately.
			if (!assemblies.Add (assembly))
				return;

			foreach (TypeDefinition type in assembly.MainModule.Types)
				MapType (type);
		}

		/// <summary>
		/// Base/interface methods for which at least one override or implementation has been recorded so far.
		/// Only assemblies that have already been processed contribute.
		/// </summary>
		public ICollection<MethodDefinition> MethodsWithOverrideInformation => override_methods.Keys;

		/// <summary>
		/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
		/// </summary>
		public List<OverrideInformation>? GetOverrides (MethodDefinition method)
		{
			EnsureProcessed (method.Module.Assembly);
			override_methods.TryGetValue (method, out List<OverrideInformation>? overrides);
			return overrides;
		}

		/// <summary>
		/// Returns all base methods that <paramref name="method"/> overrides.
		/// This includes the closest overridden virtual method on <paramref name="method"/>'s base types,
		/// methods on an interface that <paramref name="method"/>'s declaring type implements,
		/// and methods of an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
		/// The list may be incomplete if there are derived types in assemblies that haven't been processed yet that use <paramref name="method"/> to implement an interface.
		/// </summary>
		public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
		{
			EnsureProcessed (method.Module.Assembly);
			base_methods.TryGetValue (method, out List<OverrideInformation>? bases);
			return bases;
		}

		/// <summary>
		/// Returns a list of all default interface methods that implement <paramref name="baseMethod"/> for a type.
		/// ImplementingType is the type that implements the interface,
		/// InterfaceImpl is the <see cref="InterfaceImplementation" /> for the interface <paramref name="baseMethod" /> is declared on, and
		/// DefaultInterfaceMethod is the method that implements <paramref name="baseMethod"/>.
		/// The list may be incomplete if types implementing the interface live in assemblies that haven't been processed yet.
		/// </summary>
		/// <param name="baseMethod">The interface method to find default implementations for</param>
		public IEnumerable<OverrideInformation>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod)
		{
			// Like the other queries, make sure at least the declaring assembly has been mapped.
			EnsureProcessed (baseMethod.Module.Assembly);
			default_interface_implementations.TryGetValue (baseMethod, out var ret);
			return ret;
		}

		/// <summary>
		/// Records that <paramref name="method"/> overrides or implements <paramref name="base"/>.
		/// </summary>
		public void AddBaseMethod (MethodDefinition method, MethodDefinition @base, InterfaceImplementor? interfaceImplementor)
		{
			base_methods.AddToList (method, new OverrideInformation (@base, method, interfaceImplementor));
		}

		/// <summary>
		/// Records that <paramref name="override"/> overrides or implements <paramref name="base"/>.
		/// </summary>
		public void AddOverride (MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
		{
			override_methods.AddToList (@base, new OverrideInformation (@base, @override, interfaceImplementor));
		}

		/// <summary>
		/// Records that <paramref name="defaultImplementationMethod"/> is a default implementation of the interface
		/// method <paramref name="base"/> for the type described by <paramref name="interfaceImplementor"/>.
		/// </summary>
		public void AddDefaultInterfaceImplementation (MethodDefinition @base, InterfaceImplementor interfaceImplementor, MethodDefinition defaultImplementationMethod)
		{
			Debug.Assert (@base.DeclaringType.IsInterface);
			default_interface_implementations.AddToList (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
		}

		/// <summary>
		/// Records override, interface-implementation and recursive-interface information for
		/// <paramref name="type"/> and, recursively, for its nested types.
		/// </summary>
		protected virtual void MapType (TypeDefinition type)
		{
			MapVirtualMethods (type);
			MapInterfaceMethodsInTypeHierarchy (type);
			interfaces[type] = GetRecursiveInterfaceImplementations (type);

			if (!type.HasNestedTypes)
				return;

			foreach (var nested in type.NestedTypes)
				MapType (nested);
		}

		/// <summary>
		/// Returns every interface <paramref name="type"/> implements, directly or through other interfaces, paired with
		/// the chain of <see cref="InterfaceImplementation"/> records (starting on <paramref name="type"/>) that leads to it.
		/// Returns <see langword="null"/> if no entry was recorded for <paramref name="type"/>.
		/// </summary>
		internal List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)>? GetRecursiveInterfaces (TypeDefinition type)
		{
			EnsureProcessed (type.Module.Assembly);
			return interfaces.TryGetValue (type, out var value) ? value : null;
		}

		/// <summary>
		/// Computes the recursive interface list of <paramref name="type"/>. Each type's direct interfaces are added
		/// before the interfaces they in turn require; for each distinct interface only the first chain found is kept.
		/// </summary>
		List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)> GetRecursiveInterfaceImplementations (TypeDefinition type)
		{
			List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain = new ();

			AddRecursiveInterfaces (type, [], firstImplementationChain, context);
			// The last link of every chain must be the implementation of the interface the entry describes.
			Debug.Assert (firstImplementationChain.All (kvp => context.Resolve (kvp.Item1) == context.Resolve (kvp.Item2.Last ().InterfaceType)));

			return firstImplementationChain;

			// Appends each interface of typeRef (inflated with typeRef's generic arguments, if any) that is not already
			// present, then recurses into those interfaces, extending the InterfaceImplementation path as it goes.
			static void AddRecursiveInterfaces (TypeReference typeRef, IEnumerable<InterfaceImplementation> pathToType, List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain, LinkContext linkContext)
			{
				var type = linkContext.TryResolve (typeRef);
				// If we can't resolve the interface type we can't find recursive interfaces
				if (type is null)
					return;

				// Get all explicit interfaces of this type
				foreach (var iface in type.Interfaces) {
					var interfaceType = iface.InterfaceType.InflateFrom (typeRef as IGenericInstance);
					if (!firstImplementationChain.Any (i => TypeReferenceEqualityComparer.AreEqual (i.Item1, interfaceType, linkContext))) {
						firstImplementationChain.Add ((interfaceType, pathToType.Append (iface).ToList ()));
					}
				}

				// Recursive interfaces after all direct interfaces to preserve Inherit/Implement tree order
				foreach (var iface in type.Interfaces) {
					var ifaceDirectlyOnType = iface.InterfaceType.InflateFrom (typeRef as IGenericInstance);
					AddRecursiveInterfaces (ifaceDirectlyOnType, pathToType.Append (iface), firstImplementationChain, linkContext);
				}
			}
		}

		/// <summary>
		/// For each non-final virtual method of each interface returned by GetInflatedInterfaces for <paramref name="type"/>,
		/// records the implementing method: a name/signature match on <paramref name="type"/>, else one in its base types,
		/// else any default interface method(s).
		/// </summary>
		void MapInterfaceMethodsInTypeHierarchy (TypeDefinition type)
		{
			if (!type.HasInterfaces)
				return;

			// Foreach interface and for each newslot virtual method on the interface, try
			// to find the method implementation and record it.
			foreach (var interfaceImpl in type.GetInflatedInterfaces (context)) {
				foreach (MethodReference interfaceMethod in interfaceImpl.InflatedInterface.GetMethods (context)) {
					MethodDefinition? resolvedInterfaceMethod = context.TryResolve (interfaceMethod);
					if (resolvedInterfaceMethod == null)
						continue;

					// TODO-NICE: if the interface method is implemented explicitly (with an override),
					// we shouldn't need to run the below logic. This results in ILLink potentially
					// keeping more methods than needed.

					if (!resolvedInterfaceMethod.IsVirtual || resolvedInterfaceMethod.IsFinal)
						continue;

					// Static methods on interfaces must be implemented only via explicit method-impl record
					// not by a signature match. So there's no point in running this logic for static methods.
					if (!resolvedInterfaceMethod.IsStatic) {
						// Prefer a name/sig match on the current type; only then look in the base hierarchy.
						MethodDefinition? implementation = TryMatchMethod (type, interfaceMethod)
							?? GetBaseMethodInTypeHierarchy (type, interfaceMethod);
						if (implementation != null) {
							AnnotateMethods (resolvedInterfaceMethod, implementation, new (type, interfaceImpl.OriginalImpl, resolvedInterfaceMethod.DeclaringType, context));
							continue;
						}
					}

					// Look for a default implementation last.
					FindAndAddDefaultInterfaceImplementations (type, type, resolvedInterfaceMethod, interfaceImpl.OriginalImpl);
				}
			}
		}

		/// <summary>
		/// Records base-class overrides for the virtual methods of <paramref name="type"/> and explicit
		/// <c>.override</c> relationships for its virtual and static methods.
		/// </summary>
		void MapVirtualMethods (TypeDefinition type)
		{
			if (!type.HasMethods)
				return;

			foreach (MethodDefinition method in type.Methods) {
				// We do not proceed unless a method is virtual or is static
				// A static method with a .override could be implementing a static interface method
				if (!(method.IsStatic || method.IsVirtual))
					continue;

				if (method.IsVirtual)
					MapVirtualMethod (method);

				if (method.HasOverrides)
					MapOverrides (method);
			}
		}

		/// <summary>
		/// Records the closest name/signature-matching virtual method in the base types of
		/// <paramref name="method"/>'s declaring type as its base method, if one exists.
		/// </summary>
		void MapVirtualMethod (MethodDefinition method)
		{
			MethodDefinition? @base = GetBaseMethodInTypeHierarchy (method);
			if (@base == null)
				return;

			Debug.Assert (!@base.DeclaringType.IsInterface);

			AnnotateMethods (@base, method);
		}

		/// <summary>
		/// Records each resolvable method named in <paramref name="method"/>'s explicit <c>.override</c> list,
		/// attaching interface information when the overridden method is declared on an interface.
		/// </summary>
		void MapOverrides (MethodDefinition method)
		{
			foreach (MethodReference baseMethodRef in method.Overrides) {
				MethodDefinition? baseMethod = context.TryResolve (baseMethodRef);
				if (baseMethod == null)
					continue;
				if (baseMethod.DeclaringType.IsInterface) {
					AnnotateMethods (baseMethod, method, InterfaceImplementor.Create (method.DeclaringType, baseMethod.DeclaringType, context));
				} else {
					AnnotateMethods (baseMethod, method);
				}
			}
		}

		/// <summary>
		/// Records the <paramref name="base"/>/<paramref name="override"/> relationship in both directions.
		/// </summary>
		void AnnotateMethods (MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
		{
			AddBaseMethod (@override, @base, interfaceImplementor);
			AddOverride (@base, @override, interfaceImplementor);
		}

		MethodDefinition? GetBaseMethodInTypeHierarchy (MethodDefinition method)
		{
			return GetBaseMethodInTypeHierarchy (method.DeclaringType, method);
		}

		/// <summary>
		/// Walks the base types of <paramref name="type"/> (nearest first) and returns the first virtual method whose
		/// name and signature match <paramref name="method"/>, or <see langword="null"/> if none does.
		/// </summary>
		MethodDefinition? GetBaseMethodInTypeHierarchy (TypeDefinition type, MethodReference method)
		{
			TypeReference? @base = GetInflatedBaseType (type);
			while (@base != null) {
				MethodDefinition? base_method = TryMatchMethod (@base, method);
				if (base_method != null)
					return base_method;

				@base = GetInflatedBaseType (@base);
			}

			return null;
		}

		/// <summary>
		/// Returns the base type of <paramref name="type"/>, with generic arguments substituted when
		/// <paramref name="type"/> is a generic instance. Sentinel, pinned and custom-modifier wrappers are unwrapped
		/// first. Returns <see langword="null"/> for generic parameters, by-refs, pointers and unresolvable types.
		/// </summary>
		TypeReference? GetInflatedBaseType (TypeReference type)
		{
			if (type == null)
				return null;

			if (type.IsGenericParameter || type.IsByReference || type.IsPointer)
				return null;

			if (type is SentinelType sentinelType)
				return GetInflatedBaseType (sentinelType.ElementType);

			if (type is PinnedType pinnedType)
				return GetInflatedBaseType (pinnedType.ElementType);

			// Both modreq and modopt wrap the real type; unwrap them the same way.
			if (type is IModifierType modifierType)
				return GetInflatedBaseType (modifierType.ElementType);

			if (type is GenericInstanceType genericInstance) {
				var baseType = context.TryResolve (type)?.BaseType;

				if (baseType is GenericInstanceType)
					return TypeReferenceExtensions.InflateGenericType (genericInstance, baseType);

				return baseType;
			}

			return context.TryResolve (type)?.BaseType;
		}

		/// <summary>
		/// Records every default implementation of <paramref name="interfaceMethodToBeImplemented"/> found on the
		/// interfaces of <paramref name="typeThatMayHaveDIM"/>. If an interface provides no implementation, the search
		/// continues into the interfaces it requires. All matches are recorded rather than just one. This covers the
		/// diamond case (more than one most specific implementation of the given interface method): ILLink needs to
		/// preserve all the implementations so that the proper exception can be thrown at runtime.
		/// </summary>
		/// <param name="typeThatImplementsInterface">The type that implements (directly or via a base interface) the declaring interface of <paramref name="interfaceMethodToBeImplemented"/>.</param>
		/// <param name="typeThatMayHaveDIM">The type or interface whose interfaces are searched at this level of the recursion.</param>
		/// <param name="interfaceMethodToBeImplemented">The method to find a default implementation for.</param>
		/// <param name="originalInterfaceImpl">
		/// The InterfaceImplementation on <paramref name="typeThatImplementsInterface"/> that points to the DeclaringType of <paramref name="interfaceMethodToBeImplemented"/>.
		/// </param>
		void FindAndAddDefaultInterfaceImplementations (TypeDefinition typeThatImplementsInterface, TypeDefinition typeThatMayHaveDIM, MethodDefinition interfaceMethodToBeImplemented, InterfaceImplementation originalInterfaceImpl)
		{
			// Go over all interfaces, trying to find a method that is an explicit MethodImpl of the
			// interface method in question.
			foreach (var interfaceImpl in typeThatMayHaveDIM.Interfaces) {
				var potentialImplInterface = context.TryResolve (interfaceImpl.InterfaceType);
				if (potentialImplInterface == null)
					continue;

				MethodDefinition? defaultImplementation = FindDefaultImplementationOnInterface (potentialImplInterface, interfaceMethodToBeImplemented);
				if (defaultImplementation != null) {
					AddDefaultInterfaceImplementation (interfaceMethodToBeImplemented, new (typeThatImplementsInterface, originalInterfaceImpl, interfaceMethodToBeImplemented.DeclaringType, context), defaultImplementation);
					continue;
				}

				// We haven't found a MethodImpl on the current interface, but one of the interfaces
				// this interface requires could still provide it.
				FindAndAddDefaultInterfaceImplementations (typeThatImplementsInterface, potentialImplInterface, interfaceMethodToBeImplemented, originalInterfaceImpl);
			}
		}

		/// <summary>
		/// Returns the first method declared on <paramref name="candidateInterface"/> that implements
		/// <paramref name="interfaceMethod"/>. This is either <paramref name="interfaceMethod"/> itself when it is not
		/// abstract, or a method whose explicit <c>.override</c> list resolves to it. Returns <see langword="null"/>
		/// if there is none.
		/// </summary>
		MethodDefinition? FindDefaultImplementationOnInterface (TypeDefinition candidateInterface, MethodDefinition interfaceMethod)
		{
			foreach (var candidate in candidateInterface.Methods) {
				if (candidate == interfaceMethod && !candidate.IsAbstract)
					return candidate;

				if (!candidate.HasOverrides)
					continue;

				// This method is an override of something. Let's see if it's the method we are looking for.
				foreach (var overridden in candidate.Overrides) {
					if (context.TryResolve (overridden) == interfaceMethod)
						return candidate;
				}
			}

			return null;
		}

		/// <summary>
		/// Returns the first virtual method of <paramref name="type"/> whose name and signature match
		/// <paramref name="method"/>, or <see langword="null"/> if there is none. Candidates that cannot be resolved
		/// are skipped.
		/// </summary>
		MethodDefinition? TryMatchMethod (TypeReference type, MethodReference method)
		{
			foreach (var candidate in type.GetMethods (context)) {
				var md = context.TryResolve (candidate);
				if (md?.IsVirtual != true)
					continue;

				if (MethodMatch (candidate, method))
					return md;
			}

			return null;
		}

		/// <summary>
		/// Returns whether <paramref name="candidate"/> and <paramref name="method"/> have the same name, generic
		/// arity, return type and (inflated) parameter types.
		/// </summary>
		[SuppressMessage ("ApiDesign", "RS0030:Do not used banned APIs", Justification = "It's best to leave working code alone.")]
		static bool MethodMatch (MethodReference candidate, MethodReference method)
		{
			if (candidate.HasParameters != method.HasMetadataParameters ())
				return false;

			if (candidate.Name != method.Name)
				return false;

			if (candidate.HasGenericParameters != method.HasGenericParameters)
				return false;

			// Generic arity is part of the signature (M<T>() and M<T, U>() are distinct overloads), so it must be
			// compared before the early return for parameterless methods below.
			if (candidate.HasGenericParameters && candidate.GenericParameters.Count != method.GenericParameters.Count)
				return false;

			// we need to track what the generic parameter represent - as we cannot allow it to
			// differ between the return type or any parameter
			if (!TypeMatch (candidate.GetReturnType (), method.GetReturnType ()))
				return false;

			if (!candidate.HasMetadataParameters ())
				return true;

			var cp = candidate.Parameters;
			var mp = method.Parameters;
			if (cp.Count != mp.Count)
				return false;

			for (int i = 0; i < cp.Count; i++) {
				if (!TypeMatch (candidate.GetInflatedParameterType (i), method.GetInflatedParameterType (i)))
					return false;
			}

			return true;
		}

		static bool TypeMatch (IModifierType a, IModifierType b)
		{
			if (!TypeMatch (a.ModifierType, b.ModifierType))
				return false;

			return TypeMatch (a.ElementType, b.ElementType);
		}

		/// <summary>
		/// Compares two type specifications; callers guarantee both have the same runtime type.
		/// </summary>
		static bool TypeMatch (TypeSpecification a, TypeSpecification b)
		{
			if (a is GenericInstanceType gita)
				return TypeMatch (gita, (GenericInstanceType) b);

			if (a is IModifierType mta)
				return TypeMatch (mta, (IModifierType) b);

			if (a is FunctionPointerType fpta)
				return TypeMatch (fpta, (FunctionPointerType) b);

			// Arrays of different rank (e.g. int[] vs int[,]) share an element type but are distinct types.
			if (a is ArrayType arrayA && arrayA.Rank != ((ArrayType) b).Rank)
				return false;

			return TypeMatch (a.ElementType, b.ElementType);
		}

		static bool TypeMatch (GenericInstanceType a, GenericInstanceType b)
		{
			if (!TypeMatch (a.ElementType, b.ElementType))
				return false;

			if (a.HasGenericArguments != b.HasGenericArguments)
				return false;

			if (!a.HasGenericArguments)
				return true;

			var gaa = a.GenericArguments;
			var gab = b.GenericArguments;
			if (gaa.Count != gab.Count)
				return false;

			for (int i = 0; i < gaa.Count; i++) {
				if (!TypeMatch (gaa[i], gab[i]))
					return false;
			}

			return true;
		}

		/// <summary>
		/// Generic parameters match when they have the same position and the same owner kind (type vs. method).
		/// </summary>
		static bool TypeMatch (GenericParameter a, GenericParameter b)
		{
			if (a.Position != b.Position)
				return false;

			if (a.Type != b.Type)
				return false;

			return true;
		}

		static bool TypeMatch (FunctionPointerType a, FunctionPointerType b)
		{
			if (a.HasParameters != b.HasParameters)
				return false;

			if (a.CallingConvention != b.CallingConvention)
				return false;

			// we need to track what the generic parameter represent - as we cannot allow it to
			// differ between the return type or any parameter
			if (a.ReturnType is not TypeReference aReturnType ||
				b.ReturnType is not TypeReference bReturnType ||
				!TypeMatch (aReturnType, bReturnType))
				return false;

			if (!a.HasParameters)
				return true;

			var ap = a.Parameters;
			var bp = b.Parameters;
			if (ap.Count != bp.Count)
				return false;

			for (int i = 0; i < ap.Count; i++) {
				if (ap[i].ParameterType is not TypeReference aParameterType ||
					bp[i].ParameterType is not TypeReference bParameterType ||
					!TypeMatch (aParameterType, bParameterType))
					return false;
			}

			return true;
		}

		/// <summary>
		/// Structural comparison of two types as they appear in method signatures.
		/// </summary>
		static bool TypeMatch (TypeReference a, TypeReference b)
		{
			if (a is TypeSpecification || b is TypeSpecification) {
				if (a.GetType () != b.GetType ())
					return false;

				return TypeMatch ((TypeSpecification) a, (TypeSpecification) b);
			}

			if (a is GenericParameter genericParameterA && b is GenericParameter genericParameterB)
				return TypeMatch (genericParameterA, genericParameterB);

			// A generic parameter never matches a concrete type, even one whose FullName equals the parameter's name.
			if (a is GenericParameter || b is GenericParameter)
				return false;

			return a.FullName == b.FullName;
		}
	}
}