File: System\Linq\CountBy.cs
Web Access
Project: src\src\libraries\System.Linq.AsyncEnumerable\src\System.Linq.AsyncEnumerable.csproj (System.Linq.AsyncEnumerable)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
 
namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>Returns the count of elements in the source sequence grouped by key.</summary>
        /// <typeparam name="TSource">The type of elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TKey">The type of the key returned by <paramref name="keySelector"/>.</typeparam>
        /// <param name="source">A sequence that contains elements to be counted.</param>
        /// <param name="keySelector">A function to extract the key for each element.</param>
        /// <param name="keyComparer">An <see cref="IEqualityComparer{T}"/> to compare keys with.</param>
        /// <returns>An enumerable containing the frequencies of each key occurrence in <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
        public static IAsyncEnumerable<KeyValuePair<TKey, int>> CountBy<TSource, TKey>(
            this IAsyncEnumerable<TSource> source,
            Func<TSource, TKey> keySelector,
            IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull
        {
            ThrowHelper.ThrowIfNull(source);
            ThrowHelper.ThrowIfNull(keySelector);
 
            return
                source.IsKnownEmpty() ? Empty<KeyValuePair<TKey, int>>() :
                Impl(source, keySelector, keyComparer, default);
 
            static async IAsyncEnumerable<KeyValuePair<TKey, int>> Impl(
                IAsyncEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? keyComparer, [EnumeratorCancellation] CancellationToken cancellationToken)
            {
                IAsyncEnumerator<TSource> enumerator = source.GetAsyncEnumerator(cancellationToken);
                try
                {
                    if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                    {
                        Dictionary<TKey, int> countsBy = new(keyComparer);
                        do
                        {
                            TSource value = enumerator.Current;
                            TKey key = keySelector(value);
 
#if NET
                            ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _);
                            checked { currentCount++; }
#else
                            countsBy[key] = countsBy.TryGetValue(key, out int currentCount) ? checked(currentCount + 1) : 1;
#endif
                        }
                        while (await enumerator.MoveNextAsync().ConfigureAwait(false));
 
                        foreach (KeyValuePair<TKey, int> countBy in countsBy)
                        {
                            yield return countBy;
                        }
                    }
                }
                finally
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                }
            }
        }
 
        /// <summary>Returns the count of elements in the source sequence grouped by key.</summary>
        /// <typeparam name="TSource">The type of elements of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TKey">The type of the key returned by <paramref name="keySelector"/>.</typeparam>
        /// <param name="source">A sequence that contains elements to be counted.</param>
        /// <param name="keySelector">A function to extract the key for each element.</param>
        /// <param name="keyComparer">An <see cref="IEqualityComparer{T}"/> to compare keys with.</param>
        /// <returns>An enumerable containing the frequencies of each key occurrence in <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="keySelector"/> is <see langword="null"/>.</exception>
        public static IAsyncEnumerable<KeyValuePair<TKey, int>> CountBy<TSource, TKey>(
            this IAsyncEnumerable<TSource> source,
            Func<TSource, CancellationToken, ValueTask<TKey>> keySelector,
            IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull
        {
            ThrowHelper.ThrowIfNull(source);
            ThrowHelper.ThrowIfNull(keySelector);
 
            return
                source.IsKnownEmpty() ? Empty<KeyValuePair<TKey, int>>() :
                Impl(source, keySelector, keyComparer, default);
 
            static async IAsyncEnumerable<KeyValuePair<TKey, int>> Impl(
                IAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, ValueTask<TKey>> keySelector, IEqualityComparer<TKey>? keyComparer, [EnumeratorCancellation] CancellationToken cancellationToken)
            {
                IAsyncEnumerator<TSource> enumerator = source.GetAsyncEnumerator(cancellationToken);
                try
                {
                    if (await enumerator.MoveNextAsync().ConfigureAwait(false))
                    {
                        Dictionary<TKey, int> countsBy = new(keyComparer);
                        do
                        {
                            TSource value = enumerator.Current;
                            TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false);
 
#if NET
                            ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _);
                            checked { currentCount++; }
#else
                            countsBy[key] = countsBy.TryGetValue(key, out int currentCount) ? checked(currentCount + 1) : 1;
#endif
                        }
                        while (await enumerator.MoveNextAsync().ConfigureAwait(false));
 
                        foreach (KeyValuePair<TKey, int> countBy in countsBy)
                        {
                            yield return countBy;
                        }
                    }
                }
                finally
                {
                    await enumerator.DisposeAsync().ConfigureAwait(false);
                }
            }
        }
    }
}