|
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
// TypeMapInfo.cs
//
// Author:
// Jb Evain (jbevain@novell.com)
//
// (C) 2009 Novell, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Mono.Cecil;
namespace Mono.Linker
{
	/// <summary>
	/// Records, per assembly, which methods override or implement which other methods
	/// (including default interface implementations), and the full recursive interface list of each type.
	/// An assembly is mapped the first time <see cref="EnsureProcessed" /> is called for it; most query
	/// methods call it for the assembly of the member they are given.
	/// </summary>
	public class TypeMapInfo
	{
		// Assemblies that have already been mapped; prevents mapping the same assembly twice.
		readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition> ();
		readonly LinkContext context;
		// Keyed by the overriding/implementing method; each entry describes a method it overrides or implements.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
		// Keyed by the overridden (base or interface) method; each entry describes a method that overrides it.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
		// Keyed by an interface method; each entry describes a default interface method implementing it for some type.
		protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> default_interface_implementations = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
		// For each mapped type: every interface it implements (directly or through other interfaces), paired with
		// the chain of InterfaceImplementations through which that interface was first reached.
		readonly Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();

		public TypeMapInfo (LinkContext context)
		{
			this.context = context;
		}

		/// <summary>
		/// Calls <see cref="MapType" /> on every top-level type of <paramref name="assembly"/>'s main module,
		/// unless this assembly has already been processed.
		/// </summary>
		public void EnsureProcessed (AssemblyDefinition assembly)
		{
			// HashSet.Add returns false when the assembly is already recorded, i.e. already mapped.
			if (!assemblies.Add (assembly))
				return;

			foreach (TypeDefinition type in assembly.MainModule.Types)
				MapType (type);
		}

		/// <summary>
		/// The methods that currently have override information recorded (the keys of the override map).
		/// </summary>
		public ICollection<MethodDefinition> MethodsWithOverrideInformation => override_methods.Keys;

		/// <summary>
		/// Returns a list of all known methods that override <paramref name="method"/>, or null if none are recorded.
		/// The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet.
		/// </summary>
		public List<OverrideInformation>? GetOverrides (MethodDefinition method)
		{
			EnsureProcessed (method.Module.Assembly);
			override_methods.TryGetValue (method, out List<OverrideInformation>? overrides);
			return overrides;
		}

		/// <summary>
		/// Returns all base methods that <paramref name="method"/> overrides or implements, or null if none are recorded.
		/// This includes the closest overridden virtual method on <paramref name="method"/>'s base types,
		/// methods on interfaces that <paramref name="method"/>'s declaring type implements,
		/// and methods on interfaces implemented by a derived type of <paramref name="method"/>'s declaring type
		/// when the derived type uses <paramref name="method"/> as the implementing method.
		/// The list may be incomplete if there are derived types in assemblies that haven't been processed yet
		/// that use <paramref name="method"/> to implement an interface.
		/// </summary>
		public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
		{
			EnsureProcessed (method.Module.Assembly);
			base_methods.TryGetValue (method, out List<OverrideInformation>? bases);
			return bases;
		}

		/// <summary>
		/// Returns all default interface methods recorded as implementing <paramref name="baseMethod"/> for some type,
		/// or null if none are recorded.
		/// ImplementingType is the type that implements the interface,
		/// InterfaceImpl is the <see cref="InterfaceImplementation" /> for the interface <paramref name="baseMethod" /> is declared on, and
		/// DefaultInterfaceMethod is the method that implements <paramref name="baseMethod"/>.
		/// Unlike <see cref="GetOverrides" /> and <see cref="GetBaseMethods" />, this does not call
		/// <see cref="EnsureProcessed" />; results only reflect assemblies that have already been processed.
		/// </summary>
		/// <param name="baseMethod">The interface method to find default implementations for</param>
		public IEnumerable<OverrideInformation>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod)
		{
			default_interface_implementations.TryGetValue (baseMethod, out var ret);
			return ret;
		}

		/// <summary>
		/// Records that <paramref name="method"/> overrides or implements <paramref name="base"/>.
		/// </summary>
		public void AddBaseMethod (MethodDefinition method, MethodDefinition @base, InterfaceImplementor? interfaceImplementor)
		{
			base_methods.AddToList (method, new OverrideInformation (@base, method, interfaceImplementor));
		}

		/// <summary>
		/// Records that <paramref name="override"/> overrides or implements <paramref name="base"/>.
		/// </summary>
		public void AddOverride (MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
		{
			override_methods.AddToList (@base, new OverrideInformation (@base, @override, interfaceImplementor));
		}

		/// <summary>
		/// Records that <paramref name="defaultImplementationMethod"/> is a default implementation of the
		/// interface method <paramref name="base"/>, for the implementor described by <paramref name="interfaceImplementor"/>.
		/// </summary>
		public void AddDefaultInterfaceImplementation (MethodDefinition @base, InterfaceImplementor interfaceImplementor, MethodDefinition defaultImplementationMethod)
		{
			Debug.Assert (@base.DeclaringType.IsInterface);
			default_interface_implementations.AddToList (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
		}

		/// <summary>
		/// Maps the virtual methods, interface implementations and recursive interfaces of <paramref name="type"/>,
		/// then recurses into its nested types.
		/// </summary>
		protected virtual void MapType (TypeDefinition type)
		{
			MapVirtualMethods (type);
			MapInterfaceMethodsInTypeHierarchy (type);
			interfaces[type] = GetRecursiveInterfaceImplementations (type);

			if (!type.HasNestedTypes)
				return;

			foreach (var nested in type.NestedTypes)
				MapType (nested);
		}

		/// <summary>
		/// Returns every interface <paramref name="type"/> implements, directly or through other interfaces,
		/// with the chain of <see cref="InterfaceImplementation" />s through which each was first reached.
		/// Returns null if no entry has been recorded for <paramref name="type"/>.
		/// </summary>
		internal List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)>? GetRecursiveInterfaces (TypeDefinition type)
		{
			EnsureProcessed (type.Module.Assembly);
			return interfaces.TryGetValue (type, out var value) ? value : null;
		}

		// Builds the recursive interface list of a type. All interfaces declared directly on a type are
		// added before recursing into them, and an interface is only recorded the first time it is seen.
		List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)> GetRecursiveInterfaceImplementations (TypeDefinition type)
		{
			List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain = new ();

			AddRecursiveInterfaces (type, [], firstImplementationChain, context);
			// Every chain must end at an InterfaceImplementation whose interface resolves to the recorded interface type.
			Debug.Assert (firstImplementationChain.All (kvp => context.Resolve (kvp.Item1) == context.Resolve (kvp.Item2.Last ().InterfaceType)));

			return firstImplementationChain;

			// Adds the interfaces of typeRef (inflated with typeRef's generic arguments when it is a generic instance)
			// that are not already present, then recurses into each of them. pathToType is the chain of
			// InterfaceImplementations that led from the original type to typeRef.
			static void AddRecursiveInterfaces (TypeReference typeRef, IEnumerable<InterfaceImplementation> pathToType, List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain, LinkContext linkContext)
			{
				var type = linkContext.TryResolve (typeRef);
				// If we can't resolve the interface type we can't find recursive interfaces
				if (type is null)
					return;

				// Get all explicit interfaces of this type
				foreach (var iface in type.Interfaces) {
					var interfaceType = iface.InterfaceType.InflateFrom (typeRef as IGenericInstance);
					if (!firstImplementationChain.Any (i => TypeReferenceEqualityComparer.AreEqual (i.Item1, interfaceType, linkContext))) {
						firstImplementationChain.Add ((interfaceType, pathToType.Append (iface).ToList ()));
					}
				}

				// Recursive interfaces after all direct interfaces to preserve Inherit/Implement tree order
				foreach (var iface in type.Interfaces) {
					var ifaceDirectlyOnType = iface.InterfaceType.InflateFrom (typeRef as IGenericInstance);
					AddRecursiveInterfaces (ifaceDirectlyOnType, pathToType.Append (iface), firstImplementationChain, linkContext);
				}
			}
		}

		// For each interface of the type and each non-final virtual method on it, finds the implementing
		// method (name/signature match on the type, then on its base types, then default implementations)
		// and records the relationship.
		void MapInterfaceMethodsInTypeHierarchy (TypeDefinition type)
		{
			if (!type.HasInterfaces)
				return;

			// Foreach interface and for each newslot virtual method on the interface, try
			// to find the method implementation and record it.
			foreach (var interfaceImpl in type.GetInflatedInterfaces (context)) {
				foreach (MethodReference interfaceMethod in interfaceImpl.InflatedInterface.GetMethods (context)) {
					MethodDefinition? resolvedInterfaceMethod = context.TryResolve (interfaceMethod);
					if (resolvedInterfaceMethod == null)
						continue;

					// TODO-NICE: if the interface method is implemented explicitly (with an override),
					// we shouldn't need to run the below logic. This results in ILLink potentially
					// keeping more methods than needed.

					if (!resolvedInterfaceMethod.IsVirtual
						|| resolvedInterfaceMethod.IsFinal)
						continue;

					// Static methods on interfaces must be implemented only via explicit method-impl record
					// not by a signature match. So there's no point in running this logic for static methods.
					if (!resolvedInterfaceMethod.IsStatic) {
						// Try to find an implementation with a name/sig match on the current type
						MethodDefinition? exactMatchOnType = TryMatchMethod (type, interfaceMethod);
						if (exactMatchOnType != null) {
							AnnotateMethods (resolvedInterfaceMethod, exactMatchOnType, new (type, interfaceImpl.OriginalImpl, resolvedInterfaceMethod.DeclaringType, context));
							continue;
						}

						// Next try to find an implementation with a name/sig match in the base hierarchy
						var @base = GetBaseMethodInTypeHierarchy (type, interfaceMethod);
						if (@base != null) {
							AnnotateMethods (resolvedInterfaceMethod, @base, new (type, interfaceImpl.OriginalImpl, resolvedInterfaceMethod.DeclaringType, context));
							continue;
						}
					}

					// Look for a default implementation last.
					FindAndAddDefaultInterfaceImplementations (type, type, resolvedInterfaceMethod, interfaceImpl.OriginalImpl);
				}
			}
		}

		// Records base-type overrides for every virtual method, and explicit .override records for
		// virtual or static methods declared on the type.
		void MapVirtualMethods (TypeDefinition type)
		{
			if (!type.HasMethods)
				return;

			foreach (MethodDefinition method in type.Methods) {
				// We do not proceed unless a method is virtual or is static
				// A static method with a .override could be implementing a static interface method
				if (!(method.IsStatic || method.IsVirtual))
					continue;

				if (method.IsVirtual)
					MapVirtualMethod (method);

				if (method.HasOverrides)
					MapOverrides (method);
			}
		}

		// Records the closest name/signature-matching virtual method in the base type chain, if any.
		void MapVirtualMethod (MethodDefinition method)
		{
			MethodDefinition? @base = GetBaseMethodInTypeHierarchy (method);
			if (@base == null)
				return;

			Debug.Assert (!@base.DeclaringType.IsInterface);

			AnnotateMethods (@base, method);
		}

		// Records each resolvable explicit .override target of the method; interface targets also record
		// the InterfaceImplementor through which the declaring type implements that interface.
		void MapOverrides (MethodDefinition method)
		{
			foreach (MethodReference baseMethodRef in method.Overrides) {
				MethodDefinition? baseMethod = context.TryResolve (baseMethodRef);
				if (baseMethod == null)
					continue;
				if (baseMethod.DeclaringType.IsInterface) {
					AnnotateMethods (baseMethod, method, InterfaceImplementor.Create (method.DeclaringType, baseMethod.DeclaringType, context));
				} else {
					AnnotateMethods (baseMethod, method);
				}
			}
		}

		// Records the relationship in both directions: base -> override and override -> base.
		void AnnotateMethods (MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
		{
			AddBaseMethod (@override, @base, interfaceImplementor);
			AddOverride (@base, @override, interfaceImplementor);
		}

		MethodDefinition? GetBaseMethodInTypeHierarchy (MethodDefinition method)
		{
			return GetBaseMethodInTypeHierarchy (method.DeclaringType, method);
		}

		// Walks the (inflated) base types of the given type, nearest first, and returns the first virtual
		// method matching the given method by name and signature, or null.
		MethodDefinition? GetBaseMethodInTypeHierarchy (TypeDefinition type, MethodReference method)
		{
			TypeReference? @base = GetInflatedBaseType (type);
			while (@base != null) {
				MethodDefinition? base_method = TryMatchMethod (@base, method);
				if (base_method != null)
					return base_method;

				@base = GetInflatedBaseType (@base);
			}

			return null;
		}

		// Returns the base type of the given type; if both the type and its base are generic instances, the base
		// is inflated with the type's generic arguments. Generic parameters, by-ref and pointer types have no base
		// here (null); sentinel, pinned and modreq wrappers are unwrapped first.
		// NOTE(review): optional modifier (modopt) wrappers are not unwrapped here, unlike modreq — confirm intent.
		TypeReference? GetInflatedBaseType (TypeReference type)
		{
			if (type == null)
				return null;

			if (type.IsGenericParameter || type.IsByReference || type.IsPointer)
				return null;

			if (type is SentinelType sentinelType)
				return GetInflatedBaseType (sentinelType.ElementType);

			if (type is PinnedType pinnedType)
				return GetInflatedBaseType (pinnedType.ElementType);

			if (type is RequiredModifierType requiredModifierType)
				return GetInflatedBaseType (requiredModifierType.ElementType);

			if (type is GenericInstanceType genericInstance) {
				var baseType = context.TryResolve (type)?.BaseType;

				if (baseType is GenericInstanceType)
					return TypeReferenceExtensions.InflateGenericType (genericInstance, baseType);

				return baseType;
			}

			return context.TryResolve (type)?.BaseType;
		}

		/// <summary>
		/// Records the default implementations of <paramref name="interfaceMethodToBeImplemented"/> found on the
		/// interfaces of <paramref name="typeThatMayHaveDIM"/>. When an interface does not provide one, the search
		/// recurses into the interfaces that interface requires.
		/// More than one implementation may be recorded, to cover the diamond case (more than one
		/// most specific implementation of the given interface method). ILLink needs to preserve
		/// all the implementations so that the proper exception can be thrown at runtime.
		/// </summary>
		/// <param name="typeThatImplementsInterface">The type that implements (directly or via a base interface) the declaring interface of <paramref name="interfaceMethodToBeImplemented"/>; recorded as the implementing type.</param>
		/// <param name="typeThatMayHaveDIM">The type or interface whose interfaces are searched at this level of the recursion.</param>
		/// <param name="interfaceMethodToBeImplemented">The interface method to find a default implementation for.</param>
		/// <param name="originalInterfaceImpl">
		/// The InterfaceImplementation on <paramref name="typeThatImplementsInterface"/> that points to the DeclaringType of <paramref name="interfaceMethodToBeImplemented"/>.
		/// </param>
		void FindAndAddDefaultInterfaceImplementations (TypeDefinition typeThatImplementsInterface, TypeDefinition typeThatMayHaveDIM, MethodDefinition interfaceMethodToBeImplemented, InterfaceImplementation originalInterfaceImpl)
		{
			// Go over all interfaces, trying to find a method that is an explicit MethodImpl of the
			// interface method in question.

			foreach (var interfaceImpl in typeThatMayHaveDIM.Interfaces) {
				var potentialImplInterface = context.TryResolve (interfaceImpl.InterfaceType);
				if (potentialImplInterface == null)
					continue;

				bool foundImpl = false;

				foreach (var potentialImplMethod in potentialImplInterface.Methods) {
					// The interface method itself carries a body (it is not abstract).
					if (potentialImplMethod == interfaceMethodToBeImplemented &&
						!potentialImplMethod.IsAbstract) {
						AddDefaultInterfaceImplementation (interfaceMethodToBeImplemented, new (typeThatImplementsInterface, originalInterfaceImpl, interfaceMethodToBeImplemented.DeclaringType, context), potentialImplMethod);
						foundImpl = true;
						break;
					}

					if (!potentialImplMethod.HasOverrides)
						continue;

					// This method is an override of something. Let's see if it's the method we are looking for.
					foreach (var baseMethod in potentialImplMethod.Overrides) {
						if (context.TryResolve (baseMethod) == interfaceMethodToBeImplemented) {
							AddDefaultInterfaceImplementation (interfaceMethodToBeImplemented, new (typeThatImplementsInterface, originalInterfaceImpl, interfaceMethodToBeImplemented.DeclaringType, context), potentialImplMethod);
							foundImpl = true;
							break;
						}
					}

					if (foundImpl) {
						break;
					}
				}

				// We haven't found a MethodImpl on the current interface, but one of the interfaces
				// this interface requires could still provide it.
				if (!foundImpl) {
					FindAndAddDefaultInterfaceImplementations (typeThatImplementsInterface, potentialImplInterface, interfaceMethodToBeImplemented, originalInterfaceImpl);
				}
			}
		}

		// Returns the first resolvable virtual method on the given type that matches the given method by
		// name and signature, or null.
		MethodDefinition? TryMatchMethod (TypeReference type, MethodReference method)
		{
			foreach (var candidate in type.GetMethods (context)) {
				var md = context.TryResolve (candidate);
				if (md?.IsVirtual != true)
					continue;

				if (MethodMatch (candidate, method))
					return md;
			}

			return null;
		}

		// Compares name, generic arity, return type and parameter types of two methods.
		[SuppressMessage ("ApiDesign", "RS0030:Do not used banned APIs", Justification = "It's best to leave working code alone.")]
		static bool MethodMatch (MethodReference candidate, MethodReference method)
		{
			if (candidate.HasParameters != method.HasMetadataParameters ())
				return false;

			if (candidate.Name != method.Name)
				return false;

			if (candidate.HasGenericParameters != method.HasGenericParameters)
				return false;

			// Generic arity is part of the method signature. Compare it before the parameterless early-out
			// below so that e.g. M<T>() does not match M<T, U>(). Guarded by HasGenericParameters (already
			// known to be equal on both sides) so non-generic methods skip the collection access.
			if (candidate.HasGenericParameters && candidate.GenericParameters.Count != method.GenericParameters.Count)
				return false;

			// we need to track what the generic parameter represent - as we cannot allow it to
			// differ between the return type or any parameter
			if (!TypeMatch (candidate.GetReturnType (), method.GetReturnType ()))
				return false;

			if (!candidate.HasMetadataParameters ())
				return true;

			var cp = candidate.Parameters;
			var mp = method.Parameters;
			if (cp.Count != mp.Count)
				return false;

			for (int i = 0; i < cp.Count; i++) {
				if (!TypeMatch (candidate.GetInflatedParameterType (i), method.GetInflatedParameterType (i)))
					return false;
			}

			return true;
		}

		static bool TypeMatch (IModifierType a, IModifierType b)
		{
			if (!TypeMatch (a.ModifierType, b.ModifierType))
				return false;

			return TypeMatch (a.ElementType, b.ElementType);
		}

		// Only reached from TypeMatch(TypeReference, TypeReference) after it has checked that a and b have the
		// same runtime type, so the casts of b below are safe.
		// NOTE(review): any other specification (e.g. arrays) is compared by element type only, so array
		// rank/bounds are not checked here — confirm whether T[] vs T[,] should be treated as a match.
		static bool TypeMatch (TypeSpecification a, TypeSpecification b)
		{
			if (a is GenericInstanceType gita)
				return TypeMatch (gita, (GenericInstanceType) b);

			if (a is IModifierType mta)
				return TypeMatch (mta, (IModifierType) b);

			if (a is FunctionPointerType fpta)
				return TypeMatch (fpta, (FunctionPointerType) b);

			return TypeMatch (a.ElementType, b.ElementType);
		}

		static bool TypeMatch (GenericInstanceType a, GenericInstanceType b)
		{
			if (!TypeMatch (a.ElementType, b.ElementType))
				return false;

			if (a.HasGenericArguments != b.HasGenericArguments)
				return false;

			if (!a.HasGenericArguments)
				return true;

			var gaa = a.GenericArguments;
			var gab = b.GenericArguments;
			if (gaa.Count != gab.Count)
				return false;

			for (int i = 0; i < gaa.Count; i++) {
				if (!TypeMatch (gaa[i], gab[i]))
					return false;
			}

			return true;
		}

		// Generic parameters match when they have the same position and the same kind (type vs method parameter).
		static bool TypeMatch (GenericParameter a, GenericParameter b)
		{
			if (a.Position != b.Position)
				return false;

			if (a.Type != b.Type)
				return false;

			return true;
		}

		static bool TypeMatch (FunctionPointerType a, FunctionPointerType b)
		{
			if (a.HasParameters != b.HasParameters)
				return false;

			if (a.CallingConvention != b.CallingConvention)
				return false;

			// we need to track what the generic parameter represent - as we cannot allow it to
			// differ between the return type or any parameter
			if (a.ReturnType is not TypeReference aReturnType ||
				b.ReturnType is not TypeReference bReturnType ||
				!TypeMatch (aReturnType, bReturnType))
				return false;

			if (!a.HasParameters)
				return true;

			var ap = a.Parameters;
			var bp = b.Parameters;
			if (ap.Count != bp.Count)
				return false;

			for (int i = 0; i < ap.Count; i++) {
				if (ap[i].ParameterType is not TypeReference aParameterType ||
					bp[i].ParameterType is not TypeReference bParameterType ||
					!TypeMatch (aParameterType, bParameterType))
					return false;
			}

			return true;
		}

		// Entry point for type comparison. Type specifications must have the same runtime type and are compared
		// structurally, generic parameters by position/kind, and everything else by FullName.
		static bool TypeMatch (TypeReference a, TypeReference b)
		{
			if (a is TypeSpecification || b is TypeSpecification) {
				if (a.GetType () != b.GetType ())
					return false;

				return TypeMatch ((TypeSpecification) a, (TypeSpecification) b);
			}

			if (a is GenericParameter genericParameterA && b is GenericParameter genericParameterB)
				return TypeMatch (genericParameterA, genericParameterB);

			return a.FullName == b.FullName;
		}
	}
}
|