// File: System\Linq\OfType.SpeedOpt.cs
// Web Access
// Project: src\src\libraries\System.Linq\src\System.Linq.csproj (System.Linq)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
 
namespace System.Linq
{
    public static partial class Enumerable
    {
        private sealed partial class OfTypeIterator<TResult>
        {
            /// <summary>Counts the source elements that are of type <typeparamref name="TResult"/>.</summary>
            /// <param name="onlyIfCheap">
            /// When <see langword="true"/>, the source is not enumerated and -1 is returned instead of a count.
            /// </param>
            /// <returns>The number of matching elements, or -1 when <paramref name="onlyIfCheap"/> is set.</returns>
            public override int GetCount(bool onlyIfCheap)
            {
                // Producing a count means walking the whole source, so the "cheap" request is declined up front.
                if (onlyIfCheap)
                {
                    return -1;
                }

                int matches = 0;
                foreach (object? candidate in _source)
                {
                    if (candidate is not TResult)
                    {
                        continue;
                    }

                    // Checked arithmetic: overflowing the int range throws rather than wrapping.
                    matches = checked(matches + 1);
                }

                return matches;
            }
 
            /// <summary>Builds an array from the source elements that are of type <typeparamref name="TResult"/>.</summary>
            /// <returns>The array produced by the builder after every matching element has been added to it.</returns>
            public override TResult[] ToArray()
            {
                SegmentedArrayBuilder<TResult>.ScratchBuffer scratchBuffer = default;
                SegmentedArrayBuilder<TResult> matches = new SegmentedArrayBuilder<TResult>(scratchBuffer);

                foreach (object? candidate in _source)
                {
                    if (candidate is not TResult match)
                    {
                        continue;
                    }

                    matches.Add(match);
                }

                // Materialize first, then dispose the builder; Dispose is only reached when enumeration completes.
                TResult[] array = matches.ToArray();
                matches.Dispose();

                return array;
            }
 
            /// <summary>Builds a list from the source elements that are of type <typeparamref name="TResult"/>.</summary>
            /// <returns>The list produced by the builder after every matching element has been added to it.</returns>
            public override List<TResult> ToList()
            {
                SegmentedArrayBuilder<TResult>.ScratchBuffer scratchBuffer = default;
                SegmentedArrayBuilder<TResult> matches = new SegmentedArrayBuilder<TResult>(scratchBuffer);

                foreach (object? candidate in _source)
                {
                    if (candidate is not TResult match)
                    {
                        continue;
                    }

                    matches.Add(match);
                }

                // Materialize first, then dispose the builder; Dispose is only reached when enumeration completes.
                List<TResult> list = matches.ToList();
                matches.Dispose();

                return list;
            }
 
            /// <summary>Finds the first source element that is of type <typeparamref name="TResult"/>.</summary>
            /// <param name="found">
            /// Set to <see langword="true"/> when a matching element exists; otherwise <see langword="false"/>.
            /// </param>
            /// <returns>The first matching element, or the default value when there is none.</returns>
            public override TResult? TryGetFirst(out bool found)
            {
                foreach (object? candidate in _source)
                {
                    if (candidate is not TResult first)
                    {
                        continue;
                    }

                    // Stop at the very first match; the rest of the source is never examined.
                    found = true;
                    return first;
                }

                found = false;
                return default;
            }
 
            /// <summary>Finds the last source element that is of type <typeparamref name="TResult"/>.</summary>
            /// <param name="found">
            /// Set to <see langword="true"/> when a matching element exists; otherwise <see langword="false"/>.
            /// </param>
            /// <returns>The last matching element, or the default value when there is none.</returns>
            public override TResult? TryGetLast(out bool found)
            {
                IEnumerator enumerator = _source.GetEnumerator();

                // The non-generic enumerator is disposed on exit only if it happens to implement IDisposable.
                using (enumerator as IDisposable)
                {
                    // First stage: skip ahead until an element of the requested type appears.
                    while (enumerator.MoveNext())
                    {
                        if (enumerator.Current is not TResult latest)
                        {
                            continue;
                        }

                        found = true;

                        // Second stage: drain the remainder, replacing the candidate whenever a later match shows up.
                        while (enumerator.MoveNext())
                        {
                            if (enumerator.Current is TResult later)
                            {
                                latest = later;
                            }
                        }

                        return latest;
                    }
                }

                found = false;
                return default;
            }
 
            /// <summary>
            /// Finds the element at a zero-based position among the source elements that are of type
            /// <typeparamref name="TResult"/>.
            /// </summary>
            /// <param name="index">Zero-based position within the matching elements; negative values never match.</param>
            /// <param name="found">
            /// Set to <see langword="true"/> when an element exists at <paramref name="index"/>; otherwise
            /// <see langword="false"/>.
            /// </param>
            /// <returns>The element at the requested position, or the default value when there is none.</returns>
            public override TResult? TryGetElementAt(int index, out bool found)
            {
                // A negative index cannot be satisfied, so the source is not enumerated at all.
                if (index < 0)
                {
                    found = false;
                    return default;
                }

                // Number of matching elements that still have to be skipped before the wanted one.
                int remaining = index;
                foreach (object? candidate in _source)
                {
                    if (candidate is not TResult match)
                    {
                        continue;
                    }

                    if (remaining == 0)
                    {
                        found = true;
                        return match;
                    }

                    remaining--;
                }

                found = false;
                return default;
            }
 
            /// <summary>
            /// Projects the <typeparamref name="TResult"/> elements of the source through <paramref name="selector"/>,
            /// fusing the type filter and the projection into a single Where/Select iterator when the source allows it.
            /// </summary>
            /// <param name="selector">The projection applied to each element of type <typeparamref name="TResult"/>.</param>
            /// <returns>A sequence of projected values.</returns>
            public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> selector)
            {
                // The fused path requires a reference-type TResult and a source that is (covariantly) an
                // IEnumerable<object>, which any generic enumerable of a reference type is and which should be the
                // 90% case. Everything else is handled by the base implementation.
                if (typeof(TResult).IsValueType || _source is not IEnumerable<object> objectSource)
                {
                    return base.Select(selector);
                }

                // From here on OfType is treated like a Where, reusing the WhereSelectIterators that back Where.Select.
                // Reinterpreting the delegate via Unsafe.As is safe: only reference types are involved, and by
                // construction only TResult instances reach the selector. It also avoids allocating an extra closure
                // and delegate. Debug builds use a checked wrapper instead so the assumption is asserted.
                Func<object, TResult2> objectSelector =
#if DEBUG
                    o =>
                    {
                        Debug.Assert(o is TResult);
                        return selector((TResult)o);
                    };
#else
                    Unsafe.As<Func<object, TResult2>>(selector);
#endif

                Func<object, bool> matchesResultType = static o => o is TResult;

                // Arrays and IEnumerable are covariant, so each gets its dedicated WhereSelectIterator. List<T> is
                // deliberately not probed for the ListWhereSelectIterator because List<> is not covariant.
                if (objectSource is object[] objectArray)
                {
                    return new ArrayWhereSelectIterator<object, TResult2>(objectArray, matchesResultType, objectSelector);
                }

                return new IEnumerableWhereSelectIterator<object, TResult2>(objectSource, matchesResultType, objectSelector);
            }
        }
    }
}