|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Diagnostics;
using System.Dynamic.Utils;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using static System.Linq.Expressions.CachedReflectionInfo;
namespace System.Linq.Expressions.Compiler
{
internal sealed partial class LambdaCompiler
{
private void EmitBinaryExpression(Expression expr)
{
EmitBinaryExpression(expr, CompilationFlags.EmitAsNoTail);
}
private void EmitBinaryExpression(Expression expr, CompilationFlags flags)
{
BinaryExpression b = (BinaryExpression)expr;
Debug.Assert(b.NodeType != ExpressionType.AndAlso && b.NodeType != ExpressionType.OrElse && b.NodeType != ExpressionType.Coalesce);
if (b.Method != null)
{
EmitBinaryMethod(b, flags);
return;
}
// For EQ and NE, if there is a user-specified method, use it.
// Otherwise implement the C# semantics that allow equality
// comparisons on non-primitive nullable structs that don't
// overload "=="
if ((b.NodeType == ExpressionType.Equal || b.NodeType == ExpressionType.NotEqual) &&
(b.Type == typeof(bool) || b.Type == typeof(bool?)))
{
// If we have x==null, x!=null, null==x or null!=x where x is
// nullable but not null, then generate a call to x.HasValue.
Debug.Assert(!b.IsLiftedToNull || b.Type == typeof(bool?));
if (ConstantCheck.IsNull(b.Left) && !ConstantCheck.IsNull(b.Right) && b.Right.Type.IsNullableType())
{
EmitNullEquality(b.NodeType, b.Right, b.IsLiftedToNull);
return;
}
if (ConstantCheck.IsNull(b.Right) && !ConstantCheck.IsNull(b.Left) && b.Left.Type.IsNullableType())
{
EmitNullEquality(b.NodeType, b.Left, b.IsLiftedToNull);
return;
}
// For EQ and NE, we can avoid some conversions if we're
// ultimately just comparing two managed pointers.
EmitExpression(GetEqualityOperand(b.Left));
EmitExpression(GetEqualityOperand(b.Right));
}
else
{
// Otherwise generate it normally
EmitExpression(b.Left);
EmitExpression(b.Right);
}
EmitBinaryOperator(b.NodeType, b.Left.Type, b.Right.Type, b.Type, b.IsLiftedToNull);
}
private void EmitNullEquality(ExpressionType op, Expression e, bool isLiftedToNull)
{
Debug.Assert(e.Type.IsNullableType());
Debug.Assert(op == ExpressionType.Equal || op == ExpressionType.NotEqual);
// If we are lifted to null then just evaluate the expression for its side effects, discard,
// and generate null. If we are not lifted to null then generate a call to HasValue.
if (isLiftedToNull)
{
EmitExpressionAsVoid(e);
_ilg.EmitDefault(typeof(bool?), this);
}
else
{
EmitAddress(e, e.Type);
_ilg.EmitHasValue(e.Type);
if (op == ExpressionType.Equal)
{
_ilg.Emit(OpCodes.Ldc_I4_0);
_ilg.Emit(OpCodes.Ceq);
}
}
}
private void EmitBinaryMethod(BinaryExpression b, CompilationFlags flags)
{
if (b.IsLifted)
{
ParameterExpression p1 = Expression.Variable(b.Left.Type.GetNonNullableType(), name: null);
ParameterExpression p2 = Expression.Variable(b.Right.Type.GetNonNullableType(), name: null);
MethodCallExpression mc = Expression.Call(null, b.Method!, p1, p2);
Type resultType;
if (b.IsLiftedToNull)
{
resultType = mc.Type.LiftPrimitiveOrThrow();
}
else
{
Debug.Assert(mc.Type == typeof(bool));
Debug.Assert(b.NodeType == ExpressionType.Equal
|| b.NodeType == ExpressionType.NotEqual
|| b.NodeType == ExpressionType.LessThan
|| b.NodeType == ExpressionType.LessThanOrEqual
|| b.NodeType == ExpressionType.GreaterThan
|| b.NodeType == ExpressionType.GreaterThanOrEqual);
resultType = typeof(bool);
}
Debug.Assert(TypeUtils.AreReferenceAssignable(p1.Type, b.Left.Type.GetNonNullableType()));
Debug.Assert(TypeUtils.AreReferenceAssignable(p2.Type, b.Right.Type.GetNonNullableType()));
EmitLift(b.NodeType, resultType, mc, new[] { p1, p2 }, new[] { b.Left, b.Right });
}
else
{
EmitMethodCallExpression(Expression.Call(null, b.Method!, b.Left, b.Right), flags);
}
}
private void EmitBinaryOperator(ExpressionType op, Type leftType, Type rightType, Type resultType, bool liftedToNull)
{
Debug.Assert(op != ExpressionType.Coalesce);
if (op == ExpressionType.ArrayIndex)
{
Debug.Assert(rightType == typeof(int));
EmitGetArrayElement(leftType);
}
else if (leftType.IsNullableType() || rightType.IsNullableType())
{
EmitLiftedBinaryOp(op, leftType, rightType, resultType, liftedToNull);
}
else
{
EmitUnliftedBinaryOp(op, leftType, rightType);
}
}
private void EmitUnliftedBinaryOp(ExpressionType op, Type leftType, Type rightType)
{
Debug.Assert(!leftType.IsNullableType());
Debug.Assert(!rightType.IsNullableType());
Debug.Assert(leftType.IsPrimitive || (op == ExpressionType.Equal || op == ExpressionType.NotEqual) && (!leftType.IsValueType || leftType.IsEnum));
switch (op)
{
case ExpressionType.NotEqual:
if (leftType.GetTypeCode() == TypeCode.Boolean)
{
goto case ExpressionType.ExclusiveOr;
}
_ilg.Emit(OpCodes.Ceq);
_ilg.Emit(OpCodes.Ldc_I4_0);
goto case ExpressionType.Equal;
case ExpressionType.Equal:
_ilg.Emit(OpCodes.Ceq);
return;
case ExpressionType.Add:
_ilg.Emit(OpCodes.Add);
break;
case ExpressionType.AddChecked:
_ilg.Emit(leftType.IsFloatingPoint() ? OpCodes.Add : (leftType.IsUnsigned() ? OpCodes.Add_Ovf_Un : OpCodes.Add_Ovf));
break;
case ExpressionType.Subtract:
_ilg.Emit(OpCodes.Sub);
break;
case ExpressionType.SubtractChecked:
if (leftType.IsUnsigned())
{
_ilg.Emit(OpCodes.Sub_Ovf_Un);
// Guaranteed to fit within result type: no conversion
return;
}
else
{
_ilg.Emit(leftType.IsFloatingPoint() ? OpCodes.Sub : OpCodes.Sub_Ovf);
}
break;
case ExpressionType.Multiply:
_ilg.Emit(OpCodes.Mul);
break;
case ExpressionType.MultiplyChecked:
_ilg.Emit(leftType.IsFloatingPoint() ? OpCodes.Mul : (leftType.IsUnsigned() ? OpCodes.Mul_Ovf_Un : OpCodes.Mul_Ovf));
break;
case ExpressionType.Divide:
_ilg.Emit(leftType.IsUnsigned() ? OpCodes.Div_Un : OpCodes.Div);
break;
case ExpressionType.Modulo:
_ilg.Emit(leftType.IsUnsigned() ? OpCodes.Rem_Un : OpCodes.Rem);
// Guaranteed to fit within result type: no conversion
return;
case ExpressionType.And:
case ExpressionType.AndAlso:
_ilg.Emit(OpCodes.And);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.Or:
case ExpressionType.OrElse:
_ilg.Emit(OpCodes.Or);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.LessThan:
_ilg.Emit(leftType.IsUnsigned() ? OpCodes.Clt_Un : OpCodes.Clt);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.LessThanOrEqual:
_ilg.Emit(leftType.IsUnsigned() || leftType.IsFloatingPoint() ? OpCodes.Cgt_Un : OpCodes.Cgt);
_ilg.Emit(OpCodes.Ldc_I4_0);
_ilg.Emit(OpCodes.Ceq);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.GreaterThan:
_ilg.Emit(leftType.IsUnsigned() ? OpCodes.Cgt_Un : OpCodes.Cgt);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.GreaterThanOrEqual:
_ilg.Emit(leftType.IsUnsigned() || leftType.IsFloatingPoint() ? OpCodes.Clt_Un : OpCodes.Clt);
_ilg.Emit(OpCodes.Ldc_I4_0);
_ilg.Emit(OpCodes.Ceq);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.ExclusiveOr:
_ilg.Emit(OpCodes.Xor);
// Not an arithmetic operation: no conversion
return;
case ExpressionType.LeftShift:
Debug.Assert(rightType == typeof(int));
EmitShiftMask(leftType);
_ilg.Emit(OpCodes.Shl);
break;
case ExpressionType.RightShift:
Debug.Assert(rightType == typeof(int));
EmitShiftMask(leftType);
_ilg.Emit(leftType.IsUnsigned() ? OpCodes.Shr_Un : OpCodes.Shr);
// Guaranteed to fit within result type: no conversion
return;
}
EmitConvertArithmeticResult(op, leftType);
}
// Shift operations have undefined behavior if the shift amount exceeds
// the number of bits in the value operand. See CLI III.3.58 and C# 7.9
// for the bit mask used below.
private void EmitShiftMask(Type leftType)
{
int mask = leftType.IsInteger64() ? 0x3F : 0x1F;
_ilg.EmitPrimitive(mask);
_ilg.Emit(OpCodes.And);
}
// Binary/unary operations on 8 and 16 bit operand types will leave a
// 32-bit value on the stack, because that's how IL works. For these
// cases, we need to cast it back to the resultType, possibly using a
// checked conversion if the original operator was convert
private void EmitConvertArithmeticResult(ExpressionType op, Type resultType)
{
Debug.Assert(!resultType.IsNullableType());
switch (resultType.GetTypeCode())
{
case TypeCode.Byte:
_ilg.Emit(IsChecked(op) ? OpCodes.Conv_Ovf_U1 : OpCodes.Conv_U1);
break;
case TypeCode.SByte:
_ilg.Emit(IsChecked(op) ? OpCodes.Conv_Ovf_I1 : OpCodes.Conv_I1);
break;
case TypeCode.UInt16:
_ilg.Emit(IsChecked(op) ? OpCodes.Conv_Ovf_U2 : OpCodes.Conv_U2);
break;
case TypeCode.Int16:
_ilg.Emit(IsChecked(op) ? OpCodes.Conv_Ovf_I2 : OpCodes.Conv_I2);
break;
}
}
private void EmitLiftedBinaryOp(ExpressionType op, Type leftType, Type rightType, Type resultType, bool liftedToNull)
{
Debug.Assert(leftType.IsNullableType() || rightType.IsNullableType());
switch (op)
{
case ExpressionType.And:
if (leftType == typeof(bool?))
{
EmitLiftedBooleanAnd();
}
else
{
EmitLiftedBinaryArithmetic(op, leftType, rightType, resultType);
}
break;
case ExpressionType.Or:
if (leftType == typeof(bool?))
{
EmitLiftedBooleanOr();
}
else
{
EmitLiftedBinaryArithmetic(op, leftType, rightType, resultType);
}
break;
case ExpressionType.ExclusiveOr:
case ExpressionType.Add:
case ExpressionType.AddChecked:
case ExpressionType.Subtract:
case ExpressionType.SubtractChecked:
case ExpressionType.Multiply:
case ExpressionType.MultiplyChecked:
case ExpressionType.Divide:
case ExpressionType.Modulo:
case ExpressionType.LeftShift:
case ExpressionType.RightShift:
EmitLiftedBinaryArithmetic(op, leftType, rightType, resultType);
break;
case ExpressionType.LessThan:
case ExpressionType.LessThanOrEqual:
case ExpressionType.GreaterThan:
case ExpressionType.GreaterThanOrEqual:
case ExpressionType.Equal:
case ExpressionType.NotEqual:
Debug.Assert(leftType == rightType);
if (liftedToNull)
{
Debug.Assert(resultType == typeof(bool?));
EmitLiftedToNullRelational(op, leftType);
}
else
{
Debug.Assert(resultType == typeof(bool));
EmitLiftedRelational(op, leftType);
}
break;
}
}
private void EmitLiftedRelational(ExpressionType op, Type type)
{
// Equal is (left.GetValueOrDefault() == right.GetValueOrDefault()) & (left.HasValue == right.HasValue)
// NotEqual is !((left.GetValueOrDefault() == right.GetValueOrDefault()) & (left.HasValue == right.HasValue))
// Others are (left.GetValueOrDefault() op right.GetValueOrDefault()) & (left.HasValue & right.HasValue)
bool invert = op == ExpressionType.NotEqual;
if (invert)
{
op = ExpressionType.Equal;
}
LocalBuilder locLeft = GetLocal(type);
LocalBuilder locRight = GetLocal(type);
_ilg.Emit(OpCodes.Stloc, locRight);
_ilg.Emit(OpCodes.Stloc, locLeft);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitGetValueOrDefault(type);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitGetValueOrDefault(type);
Type unnullable = type.GetNonNullableType();
EmitUnliftedBinaryOp(op, unnullable, unnullable);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitHasValue(type);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitHasValue(type);
FreeLocal(locLeft);
FreeLocal(locRight);
_ilg.Emit(op == ExpressionType.Equal ? OpCodes.Ceq : OpCodes.And);
_ilg.Emit(OpCodes.And);
if (invert)
{
_ilg.Emit(OpCodes.Ldc_I4_0);
_ilg.Emit(OpCodes.Ceq);
}
}
private void EmitLiftedToNullRelational(ExpressionType op, Type type)
{
// (left.HasValue & right.HasValue) ? left.GetValueOrDefault() op right.GetValueOrDefault() : default(bool?)
Label notNull = _ilg.DefineLabel();
Label end = _ilg.DefineLabel();
LocalBuilder locLeft = GetLocal(type);
LocalBuilder locRight = GetLocal(type);
_ilg.Emit(OpCodes.Stloc, locRight);
_ilg.Emit(OpCodes.Stloc, locLeft);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitHasValue(type);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitHasValue(type);
_ilg.Emit(OpCodes.And);
_ilg.Emit(OpCodes.Brtrue_S, notNull);
_ilg.EmitDefault(typeof(bool?), this);
_ilg.Emit(OpCodes.Br_S, end);
_ilg.MarkLabel(notNull);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitGetValueOrDefault(type);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitGetValueOrDefault(type);
FreeLocal(locLeft);
FreeLocal(locRight);
Type unnullable = type.GetNonNullableType();
EmitUnliftedBinaryOp(op, unnullable, unnullable);
_ilg.Emit(OpCodes.Newobj, Nullable_Boolean_Ctor);
_ilg.MarkLabel(end);
}
private void EmitLiftedBinaryArithmetic(ExpressionType op, Type leftType, Type rightType, Type resultType)
{
bool leftIsNullable = leftType.IsNullableType();
bool rightIsNullable = rightType.IsNullableType();
Debug.Assert(leftIsNullable || rightIsNullable);
Label labIfNull = _ilg.DefineLabel();
Label labEnd = _ilg.DefineLabel();
LocalBuilder locLeft = GetLocal(leftType);
LocalBuilder locRight = GetLocal(rightType);
LocalBuilder locResult = GetLocal(resultType);
// store values (reverse order since they are already on the stack)
_ilg.Emit(OpCodes.Stloc, locRight);
_ilg.Emit(OpCodes.Stloc, locLeft);
// test for null
// don't use short circuiting
if (leftIsNullable)
{
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitHasValue(leftType);
}
if (rightIsNullable)
{
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitHasValue(rightType);
if (leftIsNullable)
{
_ilg.Emit(OpCodes.And);
}
}
_ilg.Emit(OpCodes.Brfalse_S, labIfNull);
// do op on values
if (leftIsNullable)
{
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitGetValueOrDefault(leftType);
}
else
{
_ilg.Emit(OpCodes.Ldloc, locLeft);
}
if (rightIsNullable)
{
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitGetValueOrDefault(rightType);
}
else
{
_ilg.Emit(OpCodes.Ldloc, locRight);
}
//RELEASING locLeft locRight
FreeLocal(locLeft);
FreeLocal(locRight);
Type resultNonNullableType = resultType.GetNonNullableType();
EmitBinaryOperator(op, leftType.GetNonNullableType(), rightType.GetNonNullableType(), resultNonNullableType, liftedToNull: false);
// construct result type
ConstructorInfo ci = TypeUtils.GetNullableConstructor(resultType);
_ilg.Emit(OpCodes.Newobj, ci);
_ilg.Emit(OpCodes.Stloc, locResult);
_ilg.Emit(OpCodes.Br_S, labEnd);
// if null then create a default one
_ilg.MarkLabel(labIfNull);
_ilg.Emit(OpCodes.Ldloca, locResult);
_ilg.Emit(OpCodes.Initobj, resultType);
_ilg.MarkLabel(labEnd);
_ilg.Emit(OpCodes.Ldloc, locResult);
//RELEASING locResult
FreeLocal(locResult);
}
private void EmitLiftedBooleanAnd()
{
Type type = typeof(bool?);
Label returnRight = _ilg.DefineLabel();
Label exit = _ilg.DefineLabel();
// store values (reverse order since they are already on the stack)
LocalBuilder locLeft = GetLocal(type);
LocalBuilder locRight = GetLocal(type);
_ilg.Emit(OpCodes.Stloc, locRight);
_ilg.Emit(OpCodes.Stloc, locLeft);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitGetValueOrDefault(type);
// if left == true
_ilg.Emit(OpCodes.Brtrue_S, returnRight);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitHasValue(type);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitGetValueOrDefault(type);
_ilg.Emit(OpCodes.Or);
// if !(left != null | right == true)
_ilg.Emit(OpCodes.Brfalse_S, returnRight);
_ilg.Emit(OpCodes.Ldloc, locLeft);
FreeLocal(locLeft);
_ilg.Emit(OpCodes.Br_S, exit);
_ilg.MarkLabel(returnRight);
_ilg.Emit(OpCodes.Ldloc, locRight);
FreeLocal(locRight);
_ilg.MarkLabel(exit);
}
private void EmitLiftedBooleanOr()
{
Type type = typeof(bool?);
Label returnLeft = _ilg.DefineLabel();
Label exit = _ilg.DefineLabel();
// store values (reverse order since they are already on the stack)
LocalBuilder locLeft = GetLocal(type);
LocalBuilder locRight = GetLocal(type);
_ilg.Emit(OpCodes.Stloc, locRight);
_ilg.Emit(OpCodes.Stloc, locLeft);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitGetValueOrDefault(type);
// if left == true
_ilg.Emit(OpCodes.Brtrue_S, returnLeft);
_ilg.Emit(OpCodes.Ldloca, locRight);
_ilg.EmitGetValueOrDefault(type);
_ilg.Emit(OpCodes.Ldloca, locLeft);
_ilg.EmitHasValue(type);
_ilg.Emit(OpCodes.Or);
// if !(right == true | left != null)
_ilg.Emit(OpCodes.Brfalse_S, returnLeft);
_ilg.Emit(OpCodes.Ldloc, locRight);
FreeLocal(locRight);
_ilg.Emit(OpCodes.Br_S, exit);
_ilg.MarkLabel(returnLeft);
_ilg.Emit(OpCodes.Ldloc, locLeft);
FreeLocal(locLeft);
_ilg.MarkLabel(exit);
}
}
}
|