// File: Linker.Steps\DiscoverCustomOperatorsHandler.cs
// Web Access
// Project: src\src\tools\illink\src\linker\Mono.Linker.csproj (illink)
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using ILLink.Shared.TypeSystemProxy;
using Mono.Cecil;
 
namespace Mono.Linker.Steps
{
	/// <summary>
	/// Mark handler that marks user-defined operators declared on marked types once a type named
	/// <c>System.Linq.Expressions.Expression</c> has been marked.
	/// </summary>
	/// <remarks>
	/// Before <c>System.Linq.Expressions.Expression</c> is seen, marked types that declare operators are
	/// only remembered. Once it is seen, those operators (and operators on every type marked afterwards)
	/// are marked with <see cref="DependencyKind.PreservedOperator"/>. An operator whose "other" operand or
	/// conversion type is not marked yet is deferred until that type gets marked.
	/// </remarks>
	public class DiscoverOperatorsHandler : IMarkHandler
	{
		LinkContext? _context;
		LinkContext Context {
			get {
				Debug.Assert (_context != null);
				return _context;
			}
		}

		// Set once a type named System.Linq.Expressions.Expression has been marked; never reset.
		bool _seenLinqExpressions;
		// Marked types known to declare operators, recorded while Linq.Expressions has not been seen yet.
		readonly HashSet<TypeDefinition> _trackedTypesWithOperators;
		// Operators waiting for their "other" type (the key) to become marked. Allocated lazily.
		Dictionary<TypeDefinition, List<MethodDefinition>>? _pendingOperatorsForType;

		Dictionary<TypeDefinition, List<MethodDefinition>> PendingOperatorsForType {
			get {
				_pendingOperatorsForType ??= new Dictionary<TypeDefinition, List<MethodDefinition>> ();
				return _pendingOperatorsForType;
			}
		}

		public DiscoverOperatorsHandler ()
		{
			_trackedTypesWithOperators = new HashSet<TypeDefinition> ();
		}

		/// <summary>Stores the link context and registers <see cref="ProcessType"/> as a mark-type action.</summary>
		public void Initialize (LinkContext context, MarkContext markContext)
		{
			_context = context;
			markContext.RegisterMarkTypeAction (ProcessType);
		}

		/// <summary>Mark-type action: invoked with each type that gets marked.</summary>
		void ProcessType (TypeDefinition type)
		{
			CheckForLinqExpressions (type);

			// Check for custom operators and either:
			// - mark them, if Linq.Expressions was already marked, or
			// - track them to be marked in case Linq.Expressions is marked later
			var hasOperators = ProcessCustomOperators (type, mark: _seenLinqExpressions);
			if (!_seenLinqExpressions) {
				if (hasOperators)
					_trackedTypesWithOperators.Add (type);
				return;
			}

			// Mark pending operators defined on other types that reference this type
			// (these are only tracked if we have already seen Linq.Expressions).
			// Read the backing field so that a plain lookup never allocates the dictionary.
			if (_pendingOperatorsForType != null && _pendingOperatorsForType.TryGetValue (type, out var pendingOperators)) {
				foreach (var customOperator in pendingOperators)
					MarkOperator (customOperator);
				_pendingOperatorsForType.Remove (type);
			}
		}

		/// <summary>
		/// The first time <paramref name="type"/> is <c>System.Linq.Expressions.Expression</c>, records that fact
		/// and marks (or defers) the operators of every type tracked so far.
		/// </summary>
		void CheckForLinqExpressions (TypeDefinition type)
		{
			if (_seenLinqExpressions)
				return;

			if (type.Namespace != "System.Linq.Expressions" || type.Name != "Expression")
				return;

			_seenLinqExpressions = true;

			foreach (var markedType in _trackedTypesWithOperators)
				ProcessCustomOperators (markedType, mark: true);

			_trackedTypesWithOperators.Clear ();
		}

		/// <summary>Marks <paramref name="method"/> with its declaring type as the dependency source and message origin.</summary>
		void MarkOperator (MethodDefinition method)
		{
			Context.Annotations.Mark (method, new DependencyInfo (DependencyKind.PreservedOperator, method.DeclaringType), new MessageOrigin (method.DeclaringType));
		}

		/// <summary>
		/// Scans the methods of <paramref name="type"/> for operators recognized by <see cref="IsOperator"/>.
		/// </summary>
		/// <param name="type">The type whose methods are scanned.</param>
		/// <param name="mark">
		/// When false, nothing is marked and the scan stops at the first operator found.
		/// When true, each operator is marked immediately if its other type is null or already marked,
		/// otherwise it is queued until that other type is marked.
		/// </param>
		/// <returns>True if <paramref name="type"/> declares at least one recognized operator.</returns>
		bool ProcessCustomOperators (TypeDefinition type, bool mark)
		{
			if (!type.HasMethods)
				return false;

			bool hasCustomOperators = false;
			foreach (var method in type.Methods) {
				if (!IsOperator (method, out var otherType))
					continue;

				if (!mark)
					return true;

				Debug.Assert (_seenLinqExpressions);
				hasCustomOperators = true;

				if (otherType == null || Context.Annotations.IsMarked (otherType)) {
					MarkOperator (method);
					continue;
				}

				// Wait until otherType gets marked to mark the operator.
				if (!PendingOperatorsForType.TryGetValue (otherType, out var pendingOperators)) {
					pendingOperators = new List<MethodDefinition> ();
					PendingOperatorsForType.Add (otherType, pendingOperators);
				}
				pendingOperators.Add (method);
			}
			return hasCustomOperators;
		}

		TypeDefinition? _nullableOfT;
		// Lazily looked up definition of System.Nullable<T>; may be null if it cannot be found.
		TypeDefinition? NullableOfT {
			get {
				_nullableOfT ??= BCL.FindPredefinedType (WellKnownType.System_Nullable_T, Context);
				return _nullableOfT;
			}
		}

		/// <summary>
		/// Resolves <paramref name="type"/>; if it resolves to <c>Nullable&lt;T&gt;</c>, resolves <c>T</c> instead.
		/// </summary>
		/// <returns>
		/// The resolved (unwrapped) type definition, or null if resolution fails or the
		/// <c>Nullable&lt;T&gt;</c> type argument cannot be located in the reference.
		/// </returns>
		TypeDefinition? NonNullableType (TypeReference type)
		{
			var typeDef = Context.TryResolve (type);
			if (typeDef == null)
				return null;

			if (!typeDef.IsValueType || typeDef != NullableOfT)
				return typeDef;

			// Unwrap Nullable<T>
			Debug.Assert (typeDef.HasGenericParameters);
			// The original type reference might be a TypeSpecification like array of Nullable<T>
			// that we need to unwrap until we get to the Nullable<T> instantiation.
			while (!type.IsGenericInstance) {
				// A reference that resolves to Nullable<T> without being or wrapping an instantiation
				// (e.g. the open definition itself) has no T to extract; don't crash on it.
				if (type is not TypeSpecification specification)
					return null;
				type = specification.ElementType;
			}

			if (type is not GenericInstanceType nullableType || nullableType.GenericArguments.Count != 1)
				return null;
			return Context.TryResolve (nullableType.GenericArguments[0]);
		}

		/// <summary>
		/// Determines whether <paramref name="method"/> is a public, static, special-name <c>op_*</c> method whose
		/// signature matches one of the recognized unary, binary or conversion operator shapes relative to its
		/// declaring type (with <c>Nullable&lt;T&gt;</c> unwrapped via <see cref="NonNullableType"/>).
		/// </summary>
		/// <param name="method">The candidate method.</param>
		/// <param name="otherType">
		/// For binary and conversion operators, the operand/conversion type that is not the declaring type;
		/// null when every involved type is the declaring type, or (for conversions) when the other side could not be resolved.
		/// </param>
		bool IsOperator (MethodDefinition method, out TypeDefinition? otherType)
		{
			otherType = null;

			// Metadata names are identifiers: compare ordinally, never with culture rules.
			if (!method.IsStatic || !method.IsPublic || !method.IsSpecialName || !method.Name.StartsWith ("op_", StringComparison.Ordinal))
				return false;

			var operatorName = method.Name.Substring (3);
			var self = method.DeclaringType;

			switch (operatorName) {
			// Unary operators
			case "UnaryPlus":
			case "UnaryNegation":
			case "LogicalNot":
			case "OnesComplement":
			case "Increment":
			case "Decrement":
			case "True":
			case "False":
				// Parameter type of a unary operator must be the declaring type
				if (method.GetMetadataParametersCount () != 1 || NonNullableType (method.GetParameter ((ParameterIndex) 0).ParameterType) != self)
					return false;
				// ++ and -- must return the declaring type
				if (operatorName is "Increment" or "Decrement" && NonNullableType (method.ReturnType) != self)
					return false;
				return true;
			// Binary operators
			case "Addition":
			case "Subtraction":
			case "Multiply":
			case "Division":
			case "Modulus":
			case "BitwiseAnd":
			case "BitwiseOr":
			case "ExclusiveOr":
			case "LeftShift":
			case "RightShift":
			case "Equality":
			case "Inequality":
			case "LessThan":
			case "GreaterThan":
			case "LessThanOrEqual":
			case "GreaterThanOrEqual":
				if (method.GetMetadataParametersCount () != 2)
					return false;
				var nnLeft = NonNullableType (method.GetParameter ((ParameterIndex) 0).ParameterType);
				var nnRight = NonNullableType (method.GetParameter ((ParameterIndex) 1).ParameterType);
				if (nnLeft == null || nnRight == null)
					return false;
				// << and >> must take the declaring type and int
				if (operatorName is "LeftShift" or "RightShift" && (nnLeft != self || nnRight.MetadataType != MetadataType.Int32))
					return false;
				// At least one argument must be the declaring type
				if (nnLeft != self && nnRight != self)
					return false;
				if (nnLeft != self)
					otherType = nnLeft;
				if (nnRight != self)
					otherType = nnRight;
				return true;
			// Conversion operators
			case "Implicit":
			case "Explicit":
				if (method.GetMetadataParametersCount () != 1)
					return false;
				var nnSource = NonNullableType (method.GetParameter ((ParameterIndex) 0).ParameterType);
				var nnTarget = NonNullableType (method.ReturnType);
				// Exactly one of source/target must be the declaring type
				if ((nnSource == self) == (nnTarget == self))
					return false;
				otherType = nnSource == self ? nnTarget : nnSource;
				return true;
			default:
				return false;
			}
		}
	}
}