|
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using ILLink.Shared.TypeSystemProxy;
using Mono.Cecil;
namespace Mono.Linker.Steps
{
public class DiscoverOperatorsHandler : IMarkHandler
{
    /// <summary>
    /// Mark handler that discovers user-defined operators (<c>op_*</c> special-name methods) on marked
    /// types and marks them once a type named <c>System.Linq.Expressions.Expression</c> has been marked.
    /// </summary>
    /// <remarks>
    /// Because types are observed in mark order, two kinds of work are deferred:
    /// types with operators that were marked before Linq.Expressions was seen (processed when it is
    /// first seen), and operators whose non-declaring operand/conversion type is not yet marked
    /// (marked when that other type gets marked).
    /// </remarks>
    LinkContext? _context;

    /// <summary>The link context captured by <see cref="Initialize"/>; asserted non-null on access.</summary>
    LinkContext Context
    {
        get
        {
            Debug.Assert(_context != null);
            return _context;
        }
    }

    // Set the first time a type named System.Linq.Expressions.Expression is marked; never reset.
    bool _seenLinqExpressions;

    // Types that declare custom operators and were marked before Linq.Expressions was seen.
    readonly HashSet<TypeDefinition> _trackedTypesWithOperators;

    // Operators waiting for their "other" type (the dictionary key) to be marked.
    // Allocated lazily: entries are only added after Linq.Expressions has been seen.
    Dictionary<TypeDefinition, List<MethodDefinition>>? _pendingOperatorsForType;

    /// <summary>Lazily-created map from a not-yet-marked type to the operators waiting on it.</summary>
    Dictionary<TypeDefinition, List<MethodDefinition>> PendingOperatorsForType
    {
        get
        {
            _pendingOperatorsForType ??= new Dictionary<TypeDefinition, List<MethodDefinition>>();
            return _pendingOperatorsForType;
        }
    }

    public DiscoverOperatorsHandler()
    {
        _trackedTypesWithOperators = new HashSet<TypeDefinition>();
    }

    /// <summary>
    /// Stores <paramref name="context"/> and registers <see cref="ProcessType"/> as a mark-type action.
    /// </summary>
    public void Initialize(LinkContext context, MarkContext markContext)
    {
        _context = context;
        markContext.RegisterMarkTypeAction(ProcessType);
    }

    /// <summary>
    /// Mark-type callback. Marks (or remembers, if Linq.Expressions has not been seen yet) the custom
    /// operators declared on <paramref name="type"/>, then marks any pending operators that were
    /// waiting for <paramref name="type"/> to be marked.
    /// </summary>
    void ProcessType(TypeDefinition type)
    {
        CheckForLinqExpressions(type);

        // Check for custom operators and either:
        // - mark them, if Linq.Expressions was already marked, or
        // - track them to be marked in case Linq.Expressions is marked later
        var hasOperators = ProcessCustomOperators(type, mark: _seenLinqExpressions);
        if (!_seenLinqExpressions)
        {
            if (hasOperators)
                _trackedTypesWithOperators.Add(type);
            return;
        }

        // Mark pending operators defined on other types that reference this type
        // (these are only tracked if we have already seen Linq.Expressions).
        // Use the backing field so a lookup alone does not allocate the dictionary.
        if (_pendingOperatorsForType is null || !_pendingOperatorsForType.TryGetValue(type, out var pendingOperators))
            return;

        foreach (var customOperator in pendingOperators)
            MarkOperator(customOperator);
        _pendingOperatorsForType.Remove(type);
    }

    /// <summary>
    /// On the first marked type whose namespace is <c>System.Linq.Expressions</c> and name is
    /// <c>Expression</c> (the assembly is not checked), records that Linq.Expressions was seen and
    /// processes every previously tracked type with marking enabled.
    /// </summary>
    void CheckForLinqExpressions(TypeDefinition type)
    {
        if (_seenLinqExpressions)
            return;

        if (type.Namespace != "System.Linq.Expressions" || type.Name != "Expression")
            return;

        _seenLinqExpressions = true;

        foreach (var markedType in _trackedTypesWithOperators)
            ProcessCustomOperators(markedType, mark: true);

        _trackedTypesWithOperators.Clear();
    }

    /// <summary>
    /// Marks <paramref name="method"/> as a preserved operator, attributing the dependency to its declaring type.
    /// </summary>
    void MarkOperator(MethodDefinition method)
    {
        Context.Annotations.Mark(method, new DependencyInfo(DependencyKind.PreservedOperator, method.DeclaringType), new MessageOrigin(method.DeclaringType));
    }

    /// <summary>
    /// Scans the methods of <paramref name="type"/> for custom operators.
    /// </summary>
    /// <param name="type">The type whose methods are inspected.</param>
    /// <param name="mark">
    /// When <c>false</c>, only detects operators and returns at the first one found.
    /// When <c>true</c> (only valid once Linq.Expressions has been seen), marks each operator whose
    /// other type is absent or already marked, and queues the rest until their other type is marked.
    /// </param>
    /// <returns><c>true</c> if <paramref name="type"/> declares at least one custom operator.</returns>
    bool ProcessCustomOperators(TypeDefinition type, bool mark)
    {
        if (!type.HasMethods)
            return false;

        bool hasCustomOperators = false;
        foreach (var method in type.Methods)
        {
            if (!IsOperator(method, out var otherType))
                continue;

            if (!mark)
                return true;

            Debug.Assert(_seenLinqExpressions);
            hasCustomOperators = true;

            if (otherType == null || Context.Annotations.IsMarked(otherType))
            {
                MarkOperator(method);
                continue;
            }

            // Wait until otherType gets marked to mark the operator.
            if (!PendingOperatorsForType.TryGetValue(otherType, out var pendingOperators))
            {
                pendingOperators = new List<MethodDefinition>();
                PendingOperatorsForType.Add(otherType, pendingOperators);
            }
            pendingOperators.Add(method);
        }
        return hasCustomOperators;
    }

    TypeDefinition? _nullableOfT;

    /// <summary>
    /// Cached definition of <c>System.Nullable`1</c>. If the lookup returns null, it is retried on the next access.
    /// </summary>
    TypeDefinition? NullableOfT
    {
        get
        {
            _nullableOfT ??= BCL.FindPredefinedType(WellKnownType.System_Nullable_T, Context);
            return _nullableOfT;
        }
    }

    /// <summary>
    /// Resolves <paramref name="type"/>. If it resolves to <c>Nullable&lt;T&gt;</c>, returns the
    /// resolved <c>T</c> instead.
    /// </summary>
    /// <returns>The resolved (unwrapped) definition, or <c>null</c> if resolution fails.</returns>
    TypeDefinition? NonNullableType(TypeReference type)
    {
        var typeDef = Context.TryResolve(type);
        if (typeDef == null)
            return null;

        if (!typeDef.IsValueType || typeDef != NullableOfT)
            return typeDef;

        // Unwrap Nullable<T>
        Debug.Assert(typeDef.HasGenericParameters);

        // The original type reference might be a TypeSpecification like array of Nullable<T>
        // that we need to unwrap until we get to the Nullable<T>
        while (!type.IsGenericInstance)
            type = ((TypeSpecification)type).ElementType;

        var nullableType = type as GenericInstanceType;
        Debug.Assert(nullableType != null && nullableType.HasGenericArguments && nullableType.GenericArguments.Count == 1);
        return Context.TryResolve(nullableType.GenericArguments[0]);
    }

    /// <summary>
    /// Determines whether <paramref name="method"/> is a public static special-name <c>op_*</c>
    /// method with a recognized operator name and a valid signature relative to its declaring type.
    /// </summary>
    /// <param name="method">The candidate method.</param>
    /// <param name="otherType">
    /// For binary and conversion operators, the (Nullable-unwrapped) operand or conversion type
    /// that is not the declaring type. It is <c>null</c> when there is no such type, or when it
    /// could not be resolved for a conversion operator.
    /// </param>
    /// <returns><c>true</c> if the method is treated as a custom operator.</returns>
    bool IsOperator(MethodDefinition method, out TypeDefinition? otherType)
    {
        otherType = null;

        // Metadata special names are compared ordinally; a culture-sensitive prefix check is slower
        // and not meaningful for identifiers (CA1310).
        if (!method.IsStatic || !method.IsPublic || !method.IsSpecialName || !method.Name.StartsWith("op_", StringComparison.Ordinal))
            return false;

        var operatorName = method.Name.Substring(3);
        var self = method.DeclaringType;

        switch (operatorName)
        {
            // Unary operators
            case "UnaryPlus":
            case "UnaryNegation":
            case "LogicalNot":
            case "OnesComplement":
            case "Increment":
            case "Decrement":
            case "True":
            case "False":
                // Parameter type of a unary operator must be the declaring type
                if (method.GetMetadataParametersCount() != 1 || NonNullableType(method.GetParameter((ParameterIndex)0).ParameterType) != self)
                    return false;
                // ++ and -- must return the declaring type
                if (operatorName is "Increment" or "Decrement" && NonNullableType(method.ReturnType) != self)
                    return false;
                return true;

            // Binary operators
            case "Addition":
            case "Subtraction":
            case "Multiply":
            case "Division":
            case "Modulus":
            case "BitwiseAnd":
            case "BitwiseOr":
            case "ExclusiveOr":
            case "LeftShift":
            case "RightShift":
            case "Equality":
            case "Inequality":
            case "LessThan":
            case "GreaterThan":
            case "LessThanOrEqual":
            case "GreaterThanOrEqual":
                if (method.GetMetadataParametersCount() != 2)
                    return false;
                var nnLeft = NonNullableType(method.GetParameter((ParameterIndex)0).ParameterType);
                var nnRight = NonNullableType(method.GetParameter((ParameterIndex)1).ParameterType);
                if (nnLeft == null || nnRight == null)
                    return false;
                // << and >> must take the declaring type and int
                if (operatorName is "LeftShift" or "RightShift" && (nnLeft != self || nnRight.MetadataType != MetadataType.Int32))
                    return false;
                // At least one argument must be the declaring type
                if (nnLeft != self && nnRight != self)
                    return false;
                if (nnLeft != self)
                    otherType = nnLeft;
                if (nnRight != self)
                    otherType = nnRight;
                return true;

            // Conversion operators
            case "Implicit":
            case "Explicit":
                if (method.GetMetadataParametersCount() != 1)
                    return false;
                var nnSource = NonNullableType(method.GetParameter((ParameterIndex)0).ParameterType);
                var nnTarget = NonNullableType(method.ReturnType);
                // Exactly one of source/target must be the declaring type
                if ((nnSource == self) == (nnTarget == self))
                    return false;
                otherType = nnSource == self ? nnTarget : nnSource;
                return true;

            default:
                return false;
        }
    }
}
}
|