File: System\Linq\EnumerableQuery.cs
Web Access
Project: src\src\libraries\System.Linq.Queryable\src\System.Linq.Queryable.csproj (System.Linq.Queryable)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
 
namespace System.Linq
{
    /// <summary>
    /// Non-generic base type of <see cref="EnumerableQuery{T}"/>. It exposes the query's expression and
    /// (possibly absent) backing sequence without requiring knowledge of the element type, and hosts the
    /// reflection-based factories that build a closed <see cref="EnumerableQuery{T}"/> from a runtime
    /// <see cref="Type"/>.
    /// </summary>
    public abstract class EnumerableQuery
    {
        // The constructor is internal, so derived types can only be declared where it is accessible.
        internal EnumerableQuery() { }

        /// <summary>Gets the expression associated with this query.</summary>
        internal abstract Expression Expression { get; }

        /// <summary>Gets the backing sequence of this query, or <see langword="null"/> when there is none.</summary>
        internal abstract IEnumerable? Enumerable { get; }

        /// <summary>
        /// Builds an <c>EnumerableQuery&lt;elementType&gt;</c> by handing <paramref name="sequence"/> to its constructor.
        /// </summary>
        /// <param name="elementType">The type argument used to close <see cref="EnumerableQuery{T}"/>.</param>
        /// <param name="sequence">The single constructor argument.</param>
        /// <returns>The newly created instance, cast to <see cref="IQueryable"/>.</returns>
        [RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
        [RequiresDynamicCode(Queryable.InMemoryQueryableExtensionMethodsRequiresDynamicCode)]
        internal static IQueryable Create(Type elementType, IEnumerable sequence)
            => Instantiate(elementType, sequence);

        /// <summary>
        /// Builds an <c>EnumerableQuery&lt;elementType&gt;</c> by handing <paramref name="expression"/> to its constructor.
        /// </summary>
        /// <param name="elementType">The type argument used to close <see cref="EnumerableQuery{T}"/>.</param>
        /// <param name="expression">The single constructor argument.</param>
        /// <returns>The newly created instance, cast to <see cref="IQueryable"/>.</returns>
        [RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
        [RequiresDynamicCode(Queryable.InMemoryQueryableExtensionMethodsRequiresDynamicCode)]
        internal static IQueryable Create(Type elementType, Expression expression)
            => Instantiate(elementType, expression);

        // Shared reflection path for both factories: close EnumerableQuery<> over elementType and let
        // Activator pick the constructor from the single argument supplied.
        [RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
        [RequiresDynamicCode(Queryable.InMemoryQueryableExtensionMethodsRequiresDynamicCode)]
        private static IQueryable Instantiate(Type elementType, object constructorArgument)
        {
            Type closedQueryType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
            object? instance = Activator.CreateInstance(closedQueryType, constructorArgument);
            return (IQueryable)instance!;
        }
    }
 
    /// <summary>
    /// An <see cref="IOrderedQueryable{T}"/> that is backed either by an in-memory
    /// <see cref="IEnumerable{T}"/> or by an expression tree, and that serves as its own
    /// <see cref="IQueryProvider"/>.
    /// </summary>
    /// <typeparam name="T">The element type of the query.</typeparam>
    [RequiresDynamicCode(Queryable.InMemoryQueryableExtensionMethodsRequiresDynamicCode)]
    [RequiresUnreferencedCode(Queryable.InMemoryQueryableExtensionMethodsRequiresUnreferencedCode)]
    public class EnumerableQuery<T> : EnumerableQuery, IOrderedQueryable<T>, IQueryProvider
    {
        private readonly Expression _expression;

        // Supplied by the sequence constructor; otherwise filled in on first enumeration.
        private IEnumerable<T>? _enumerable;

        /// <summary>
        /// Creates a query over <paramref name="enumerable"/>. The query's expression is a constant
        /// expression whose value is this very instance.
        /// </summary>
        /// <param name="enumerable">The sequence to expose as a query.</param>
        public EnumerableQuery(IEnumerable<T> enumerable)
        {
            _enumerable = enumerable;
            _expression = Expression.Constant(this);
        }

        /// <summary>
        /// Creates a query described by <paramref name="expression"/>; no sequence is stored until the
        /// query is first enumerated.
        /// </summary>
        /// <param name="expression">The expression tree that describes the query.</param>
        public EnumerableQuery(Expression expression) => _expression = expression;

        /// <inheritdoc />
        internal override Expression Expression => _expression;

        /// <inheritdoc />
        internal override IEnumerable? Enumerable => _enumerable;

        /// <summary>Gets the expression tree associated with this query.</summary>
        Expression IQueryable.Expression => _expression;

        /// <summary>Gets the element type, which is always <typeparamref name="T"/>.</summary>
        Type IQueryable.ElementType => typeof(T);

        /// <summary>Gets the provider for this query, which is this instance itself.</summary>
        IQueryProvider IQueryable.Provider => this;

        /// <summary>
        /// Creates a query for <paramref name="expression"/>, taking the element type from the
        /// <see cref="IQueryable{T}"/> type that <c>TypeHelper.FindGenericType</c> locates for the expression's type.
        /// </summary>
        /// <param name="expression">The expression tree to wrap; must not be <see langword="null"/>.</param>
        /// <returns>A new query created through the reflection-based factory on the base class.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>Error.ArgumentNotValid</c> when no
        /// <see cref="IQueryable{T}"/> type is found for the expression's type.
        /// </remarks>
        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            ArgumentNullException.ThrowIfNull(expression);

            Type queryableType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type)
                ?? throw Error.ArgumentNotValid(nameof(expression));

            Type elementType = queryableType.GetGenericArguments()[0];
            return Create(elementType, expression);
        }

        /// <summary>
        /// Creates a strongly typed query for <paramref name="expression"/>.
        /// </summary>
        /// <typeparam name="TElement">The element type of the resulting query.</typeparam>
        /// <param name="expression">The expression tree to wrap; must not be <see langword="null"/>.</param>
        /// <returns>A new <see cref="EnumerableQuery{T}"/> of <typeparamref name="TElement"/> over the expression.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>Error.ArgumentNotValid</c> when the expression's type is
        /// not assignable to <see cref="IQueryable{T}"/> of <typeparamref name="TElement"/>.
        /// </remarks>
        IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
        {
            ArgumentNullException.ThrowIfNull(expression);

            if (typeof(IQueryable<TElement>).IsAssignableFrom(expression.Type))
            {
                return new EnumerableQuery<TElement>(expression);
            }

            throw Error.ArgumentNotValid(nameof(expression));
        }

        /// <summary>
        /// Executes <paramref name="expression"/> through an <c>EnumerableExecutor</c> and returns the
        /// result as an object.
        /// </summary>
        /// <param name="expression">The expression tree to execute; must not be <see langword="null"/>.</param>
        /// <returns>The value returned by the executor's <c>ExecuteBoxed</c> method.</returns>
        object? IQueryProvider.Execute(Expression expression)
        {
            ArgumentNullException.ThrowIfNull(expression);

            var executor = EnumerableExecutor.Create(expression);
            return executor.ExecuteBoxed();
        }

        /// <summary>
        /// Executes <paramref name="expression"/> and returns a strongly typed result.
        /// </summary>
        /// <typeparam name="TElement">The type of the result.</typeparam>
        /// <param name="expression">The expression tree to execute; must not be <see langword="null"/>.</param>
        /// <returns>The value returned by the executor's <c>Execute</c> method.</returns>
        /// <remarks>
        /// Throws the exception produced by <c>Error.ArgumentNotValid</c> when the expression's type is
        /// not assignable to <typeparamref name="TElement"/>.
        /// </remarks>
        TElement IQueryProvider.Execute<TElement>(Expression expression)
        {
            ArgumentNullException.ThrowIfNull(expression);

            if (typeof(TElement).IsAssignableFrom(expression.Type))
            {
                return new EnumerableExecutor<TElement>(expression).Execute();
            }

            throw Error.ArgumentNotValid(nameof(expression));
        }

        /// <summary>Returns an enumerator over the query's results.</summary>
        IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();

        /// <summary>Returns a non-generic enumerator over the query's results.</summary>
        IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

        // Enumerates the stored sequence, first producing and caching it from the expression when no
        // sequence has been stored yet.
        private IEnumerator<T> GetEnumerator()
        {
            _enumerable ??= EvaluateExpression();
            return _enumerable.GetEnumerator();
        }

        // Passes the expression through an EnumerableRewriter, compiles the outcome as a parameterless
        // lambda returning IEnumerable<T>, and invokes it.
        private IEnumerable<T> EvaluateExpression()
        {
            Expression rewritten = new EnumerableRewriter().Visit(_expression);
            Expression<Func<IEnumerable<T>>> lambda =
                Expression.Lambda<Func<IEnumerable<T>>>(rewritten, (IEnumerable<ParameterExpression>?)null);
            Func<IEnumerable<T>> producer = lambda.Compile();
            IEnumerable<T> result = producer();

            // A result that is this very instance cannot be enumerated; reject it instead of caching it.
            if (ReferenceEquals(result, this))
            {
                throw Error.EnumeratingNullEnumerableExpression();
            }

            return result;
        }

        /// <summary>
        /// Returns a textual representation of the query: when the expression is a constant that holds
        /// this instance, the stored sequence's own string (or the literal <c>"null"</c> if there is no
        /// sequence); otherwise the expression's string.
        /// </summary>
        public override string? ToString()
        {
            bool expressionIsSelf =
                _expression is ConstantExpression constant && ReferenceEquals(constant.Value, this);

            if (!expressionIsSelf)
            {
                return _expression.ToString();
            }

            return _enumerable is null ? "null" : _enumerable.ToString();
        }
    }
}