|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;
namespace System.Linq
{
[RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
[RequiresDynamicCode("Requires MakeGenericType")]
internal sealed class EnumerableRewriter : ExpressionVisitor
{
// We must ensure that if a LabelTarget is rewritten that it is always rewritten to the same new target
// or otherwise expressions using it won't match correctly.
private Dictionary<LabelTarget, LabelTarget>? _targetCache;
// Finding equivalent types can be relatively expensive, and hitting with the same types repeatedly is quite likely.
private Dictionary<Type, Type>? _equivalentTypeCache;
public EnumerableRewriter()
{
}
protected override Expression VisitMethodCall(MethodCallExpression m)
{
Expression? obj = Visit(m.Object);
ReadOnlyCollection<Expression> args = Visit(m.Arguments);
// check for args changed
if (obj != m.Object || args != m.Arguments)
{
MethodInfo mInfo = m.Method;
Type[]? typeArgs = (mInfo.IsGenericMethod) ? mInfo.GetGenericArguments() : null;
if ((mInfo.IsStatic || mInfo.DeclaringType!.IsAssignableFrom(obj!.Type))
&& ArgsMatch(mInfo, args, typeArgs))
{
// current method is still valid
return Expression.Call(obj, mInfo, args);
}
else if (mInfo.DeclaringType == typeof(Queryable))
{
// convert Queryable method to Enumerable method
MethodInfo seqMethod = FindEnumerableMethodForQueryable(mInfo.Name, args, typeArgs);
args = FixupQuotedArgs(seqMethod, args);
return Expression.Call(obj, seqMethod, args);
}
else
{
// rebind to new method
MethodInfo method = FindMethod(mInfo.DeclaringType!, mInfo.Name, args, typeArgs);
args = FixupQuotedArgs(method, args);
return Expression.Call(obj, method, args);
}
}
return m;
}
private static ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCollection<Expression> argList)
{
ParameterInfo[] pis = mi.GetParameters();
if (pis.Length > 0)
{
List<Expression>? newArgs = null;
for (int i = 0, n = pis.Length; i < n; i++)
{
Expression arg = argList[i];
ParameterInfo pi = pis[i];
arg = FixupQuotedExpression(pi.ParameterType, arg);
if (newArgs == null && arg != argList[i])
{
newArgs = new List<Expression>(argList.Count);
for (int j = 0; j < i; j++)
{
newArgs.Add(argList[j]);
}
}
newArgs?.Add(arg);
}
if (newArgs != null)
argList = newArgs.AsReadOnly();
}
return argList;
}
private static Expression FixupQuotedExpression(Type type, Expression expression)
{
Expression expr = expression;
while (true)
{
if (type.IsAssignableFrom(expr.Type))
return expr;
if (expr.NodeType != ExpressionType.Quote)
break;
expr = ((UnaryExpression)expr).Operand;
}
if (!type.IsAssignableFrom(expr.Type) && type.IsArray && expr.NodeType == ExpressionType.NewArrayInit)
{
Type strippedType = StripExpression(expr.Type);
if (type.IsAssignableFrom(strippedType))
{
Type elementType = type.GetElementType()!;
NewArrayExpression na = (NewArrayExpression)expr;
List<Expression> exprs = new List<Expression>(na.Expressions.Count);
for (int i = 0, n = na.Expressions.Count; i < n; i++)
{
exprs.Add(FixupQuotedExpression(elementType, na.Expressions[i]));
}
expression = Expression.NewArrayInit(elementType, exprs);
}
}
return expression;
}
protected override Expression VisitLambda<T>(Expression<T> node) => node;
private static Type GetPublicType(Type t)
{
// If we create a constant explicitly typed to be a private nested type,
// such as Lookup<,>.Grouping or a compiler-generated iterator class, then
// we cannot use the expression tree in a context which has only execution
// permissions. We should endeavour to translate constants into
// new constants which have public types.
if (t.IsGenericType && t.GetGenericTypeDefinition().GetInterfaces().Contains(typeof(IGrouping<,>)))
return typeof(IGrouping<,>).MakeGenericType(t.GetGenericArguments());
if (!t.IsNestedPrivate)
return t;
foreach (Type iType in t.GetInterfaces())
{
if (iType.IsGenericType && iType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
return iType;
}
if (typeof(IEnumerable).IsAssignableFrom(t))
return typeof(IEnumerable);
return t;
}
private Type GetEquivalentType(Type type)
{
Type? equiv;
// Pre-loading with the non-generic IQueryable and IEnumerable not only covers this case
// without any reflection-based introspection, but also means the slightly different
// code needed to catch this case can be omitted safely.
_equivalentTypeCache ??= new Dictionary<Type, Type>
{
{ typeof(IQueryable), typeof(IEnumerable) },
{ typeof(IEnumerable), typeof(IEnumerable) }
};
if (!_equivalentTypeCache.TryGetValue(type, out equiv))
{
Type pubType = GetPublicType(type);
if (pubType.IsInterface && pubType.IsGenericType)
{
Type genericType = pubType.GetGenericTypeDefinition();
if (genericType == typeof(IOrderedEnumerable<>))
equiv = pubType;
else if (genericType == typeof(IOrderedQueryable<>))
equiv = typeof(IOrderedEnumerable<>).MakeGenericType(pubType.GenericTypeArguments[0]);
else if (genericType == typeof(IEnumerable<>))
equiv = pubType;
else if (genericType == typeof(IQueryable<>))
equiv = typeof(IEnumerable<>).MakeGenericType(pubType.GenericTypeArguments[0]);
}
if (equiv == null)
{
var interfacesWithInfo = pubType.GetInterfaces();
var singleTypeGenInterfacesWithGetType = interfacesWithInfo
.Where(i => i.IsGenericType && i.GenericTypeArguments.Length == 1)
.Select(i => new { Info = i, GenType = i.GetGenericTypeDefinition() })
.ToArray();
Type? typeArg = singleTypeGenInterfacesWithGetType
.Where(i => i.GenType == typeof(IOrderedQueryable<>) || i.GenType == typeof(IOrderedEnumerable<>))
.Select(i => i.Info.GenericTypeArguments[0])
.Distinct()
.SingleOrDefault();
if (typeArg != null)
equiv = typeof(IOrderedEnumerable<>).MakeGenericType(typeArg);
else
{
typeArg = singleTypeGenInterfacesWithGetType
.Where(i => i.GenType == typeof(IQueryable<>) || i.GenType == typeof(IEnumerable<>))
.Select(i => i.Info.GenericTypeArguments[0])
.Distinct()
.Single();
equiv = typeof(IEnumerable<>).MakeGenericType(typeArg);
}
}
_equivalentTypeCache.Add(type, equiv);
}
return equiv;
}
protected override Expression VisitConstant(ConstantExpression c)
{
if (c.Value is EnumerableQuery sq)
{
if (sq.Enumerable != null)
{
Type t = GetPublicType(sq.Enumerable.GetType());
return Expression.Constant(sq.Enumerable, t);
}
Expression exp = sq.Expression;
if (exp != c)
return Visit(exp);
}
return c;
}
private static ILookup<string, MethodInfo>? s_seqMethods;
private static MethodInfo FindEnumerableMethodForQueryable(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
{
s_seqMethods ??= GetEnumerableStaticMethods(typeof(Enumerable)).ToLookup(m => m.Name);
MethodInfo[] matchingMethods = s_seqMethods[name]
.Where(m => ArgsMatch(m, args, typeArgs))
.Select(ApplyTypeArgs)
.ToArray();
Debug.Assert(matchingMethods.Length > 0, "All static methods with arguments on Queryable have equivalents on Enumerable.");
if (matchingMethods.Length > 1)
{
return DisambiguateMatches(matchingMethods);
}
return matchingMethods[0];
static MethodInfo[] GetEnumerableStaticMethods(Type type) =>
type.GetMethods(BindingFlags.Public | BindingFlags.Static);
[RequiresDynamicCodeAttribute("Calls System.Reflection.MethodInfo.MakeGenericMethod(params Type[])")]
MethodInfo ApplyTypeArgs(MethodInfo methodInfo) => typeArgs == null ? methodInfo : methodInfo.MakeGenericMethod(typeArgs);
// In certain cases, there might be ambiguities when resolving matching overloads, for example between
// 1. FirstOrDefault<object>(IEnumerable<object> source, Func<object, bool> predicate) and
// 2. FirstOrDefault<object>(IEnumerable<object> source, object defaultvalue).
// In such cases we disambiguate by picking a method with the most derived signature.
static MethodInfo DisambiguateMatches(MethodInfo[] matchingMethods)
{
Debug.Assert(matchingMethods.Length > 1);
ParameterInfo[][] parameters = matchingMethods.Select(m => m.GetParameters()).ToArray();
// `AreAssignableFrom[Strict]` defines a partial order on method signatures; pick a maximal element using that order.
// It is assumed that `matchingMethods` is a small array, so a naive quadratic search is probably better than
// doing some variant of topological sorting.
for (int i = 0; i < matchingMethods.Length; i++)
{
bool isMaximal = true;
for (int j = 0; j < matchingMethods.Length; j++)
{
if (i != j && AreAssignableFromStrict(parameters[i], parameters[j]))
{
// Found a matching method that contains strictly more specific parameter types.
isMaximal = false;
break;
}
}
if (isMaximal)
{
return matchingMethods[i];
}
}
Debug.Fail("Search should have found a maximal element");
throw new Exception();
static bool AreAssignableFromStrict(ParameterInfo[] left, ParameterInfo[] right)
{
Debug.Assert(left.Length == right.Length);
bool areEqual = true;
bool areAssignableFrom = true;
for (int i = 0; i < left.Length; i++)
{
Type leftParam = left[i].ParameterType;
Type rightParam = right[i].ParameterType;
areEqual = areEqual && leftParam == rightParam;
areAssignableFrom = areAssignableFrom && leftParam.IsAssignableFrom(rightParam);
}
return !areEqual && areAssignableFrom;
}
}
}
[RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
[RequiresDynamicCode("Calls System.Reflection.MethodInfo.MakeGenericMethod(params Type[])")]
private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
{
using (IEnumerator<MethodInfo> en = type.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static).Where(m => m.Name == name).GetEnumerator())
{
if (!en.MoveNext())
throw Error.NoMethodOnType(name, type);
do
{
MethodInfo mi = en.Current;
if (ArgsMatch(mi, args, typeArgs))
return (typeArgs != null) ? mi.MakeGenericMethod(typeArgs) : mi;
} while (en.MoveNext());
}
throw Error.NoMethodOnTypeMatchingArguments(name, type);
}
private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
{
ParameterInfo[] mParams = m.GetParameters();
if (mParams.Length != args.Count)
return false;
if (!m.IsGenericMethod && typeArgs != null && typeArgs.Length > 0)
{
return false;
}
if (!m.IsGenericMethodDefinition && m.IsGenericMethod && m.ContainsGenericParameters)
{
m = m.GetGenericMethodDefinition();
}
if (m.IsGenericMethodDefinition)
{
if (typeArgs == null || typeArgs.Length == 0)
return false;
if (m.GetGenericArguments().Length != typeArgs.Length)
return false;
mParams = GetConstrutedGenericParameters(m, typeArgs);
[RequiresDynamicCodeAttribute("Calls System.Reflection.MethodInfo.MakeGenericMethod(params Type[])")]
static ParameterInfo[] GetConstrutedGenericParameters(MethodInfo method, Type[] genericTypes) =>
method.MakeGenericMethod(genericTypes).GetParameters();
}
for (int i = 0, n = args.Count; i < n; i++)
{
Type parameterType = mParams[i].ParameterType;
if (parameterType == null)
return false;
if (parameterType.IsByRef)
parameterType = parameterType.GetElementType()!;
Expression arg = args[i];
if (!parameterType.IsAssignableFrom(arg.Type))
{
if (arg.NodeType == ExpressionType.Quote)
{
arg = ((UnaryExpression)arg).Operand;
}
if (!parameterType.IsAssignableFrom(arg.Type) &&
!parameterType.IsAssignableFrom(StripExpression(arg.Type)))
{
return false;
}
}
}
return true;
}
[RequiresDynamicCode("Calls System.Type.MakeArrayType()")]
private static Type StripExpression(Type type)
{
bool isArray = type.IsArray;
Type tmp = isArray ? type.GetElementType()! : type;
Type? eType = TypeHelper.FindGenericType(typeof(Expression<>), tmp);
if (eType != null)
tmp = eType.GetGenericArguments()[0];
if (isArray)
{
int rank = type.GetArrayRank();
return (rank == 1) ? tmp.MakeArrayType() : tmp.MakeArrayType(rank);
}
return type;
}
protected override Expression VisitConditional(ConditionalExpression c)
{
Type type = c.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
return base.VisitConditional(c);
Expression test = Visit(c.Test);
Expression ifTrue = Visit(c.IfTrue);
Expression ifFalse = Visit(c.IfFalse);
Type trueType = ifTrue.Type;
Type falseType = ifFalse.Type;
if (trueType.IsAssignableFrom(falseType))
return Expression.Condition(test, ifTrue, ifFalse, trueType);
if (falseType.IsAssignableFrom(trueType))
return Expression.Condition(test, ifTrue, ifFalse, falseType);
return Expression.Condition(test, ifTrue, ifFalse, GetEquivalentType(type));
}
protected override Expression VisitBlock(BlockExpression node)
{
Type type = node.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
return base.VisitBlock(node);
ReadOnlyCollection<Expression> nodes = Visit(node.Expressions);
ReadOnlyCollection<ParameterExpression> variables = VisitAndConvert(node.Variables, "EnumerableRewriter.VisitBlock");
if (type == node.Expressions.Last().Type)
return Expression.Block(variables, nodes);
return Expression.Block(GetEquivalentType(type), variables, nodes);
}
protected override Expression VisitGoto(GotoExpression node)
{
Type type = node.Value!.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
return base.VisitGoto(node);
LabelTarget target = VisitLabelTarget(node.Target);
Expression value = Visit(node.Value);
return Expression.MakeGoto(node.Kind, target, value, GetEquivalentType(typeof(EnumerableQuery).IsAssignableFrom(type) ? value.Type : type));
}
protected override LabelTarget VisitLabelTarget(LabelTarget? node)
{
LabelTarget? newTarget;
if (_targetCache == null)
{
_targetCache = new Dictionary<LabelTarget, LabelTarget>();
}
else if (_targetCache.TryGetValue(node!, out newTarget))
{
return newTarget;
}
Type type = node!.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
newTarget = base.VisitLabelTarget(node);
else
newTarget = Expression.Label(GetEquivalentType(type), node.Name);
_targetCache.Add(node, newTarget);
return newTarget;
}
}
}
|