File: System\Linq\AggregateBy.cs
Web Access
Project: src\src\libraries\System.Linq\src\System.Linq.csproj (System.Linq)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Runtime.InteropServices;
 
namespace System.Linq
{
    public static partial class Enumerable
    {
        /// <summary>
        /// Groups the elements of a sequence by key and folds each group into a single accumulated value,
        /// starting every group from the same <paramref name="seed"/>.
        /// </summary>
        /// <typeparam name="TSource">The element type of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TKey">The type of key produced by <paramref name="keySelector"/>.</typeparam>
        /// <typeparam name="TAccumulate">The type of the per-key accumulated value.</typeparam>
        /// <param name="source">The <see cref="IEnumerable{T}"/> whose elements are aggregated.</param>
        /// <param name="keySelector">Maps each element to the key of the group it belongs to.</param>
        /// <param name="seed">The starting accumulator value used for a key the first time that key is seen.</param>
        /// <param name="func">Combines the current accumulator of a key with the next element having that key.</param>
        /// <param name="keyComparer">
        /// The <see cref="IEqualityComparer{T}"/> used to compare keys, or <see langword="null"/> to use the default comparer.
        /// </param>
        /// <returns>A sequence of key/accumulator pairs, one per distinct key found in <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/>, or <paramref name="func"/> is <see langword="null"/>.
        /// </exception>
        /// <remarks>
        /// <para>
        /// This behaves like the <see cref="GroupBy{TSource, TKey}(IEnumerable{TSource}, Func{TSource, TKey})"/> methods,
        /// except that each group is reduced to one value instead of being materialized as a collection.
        /// </para>
        /// <para>
        /// Arguments are validated when this method is called. Enumerating the result reads the entire source
        /// before the first pair is produced. The same <paramref name="seed"/> value is handed to <paramref name="func"/>
        /// for the first element of every key, so a mutable reference-type seed ends up shared by all groups; use the
        /// overload that takes a <c>seedSelector</c> when each key needs its own starting accumulator.
        /// </para>
        /// </remarks>
        public static IEnumerable<KeyValuePair<TKey, TAccumulate>> AggregateBy<TSource, TKey, TAccumulate>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector,
            TAccumulate seed,
            Func<TAccumulate, TSource, TAccumulate> func,
            IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull
        {
            // Checked here rather than inside the iterator so null arguments surface at call time.
            // The order of the checks decides which argument is reported when several are null.
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (keySelector is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector);
            }

            if (func is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.func);
            }

            // An empty array cannot yield any group, so skip creating the iterator altogether.
            return IsEmptyArray(source)
                ? []
                : AggregateByIterator(source, keySelector, seed, func, keyComparer);
        }

        /// <summary>
        /// Groups the elements of a sequence by key and folds each group into a single accumulated value,
        /// obtaining each group's starting value from <paramref name="seedSelector"/>.
        /// </summary>
        /// <typeparam name="TSource">The element type of <paramref name="source"/>.</typeparam>
        /// <typeparam name="TKey">The type of key produced by <paramref name="keySelector"/>.</typeparam>
        /// <typeparam name="TAccumulate">The type of the per-key accumulated value.</typeparam>
        /// <param name="source">The <see cref="IEnumerable{T}"/> whose elements are aggregated.</param>
        /// <param name="keySelector">Maps each element to the key of the group it belongs to.</param>
        /// <param name="seedSelector">
        /// Produces the starting accumulator for a key; it is invoked once per distinct key, when that key is first seen.
        /// </param>
        /// <param name="func">Combines the current accumulator of a key with the next element having that key.</param>
        /// <param name="keyComparer">
        /// The <see cref="IEqualityComparer{T}"/> used to compare keys, or <see langword="null"/> to use the default comparer.
        /// </param>
        /// <returns>A sequence of key/accumulator pairs, one per distinct key found in <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="source"/>, <paramref name="keySelector"/>, <paramref name="seedSelector"/>, or <paramref name="func"/>
        /// is <see langword="null"/>.
        /// </exception>
        /// <remarks>
        /// <para>
        /// This behaves like the <see cref="GroupBy{TSource, TKey}(IEnumerable{TSource}, Func{TSource, TKey})"/> methods,
        /// except that each group is reduced to one value instead of being materialized as a collection.
        /// </para>
        /// <para>
        /// Arguments are validated when this method is called. Enumerating the result reads the entire source
        /// before the first pair is produced.
        /// </para>
        /// </remarks>
        public static IEnumerable<KeyValuePair<TKey, TAccumulate>> AggregateBy<TSource, TKey, TAccumulate>(
            this IEnumerable<TSource> source,
            Func<TSource, TKey> keySelector,
            Func<TKey, TAccumulate> seedSelector,
            Func<TAccumulate, TSource, TAccumulate> func,
            IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull
        {
            // Checked here rather than inside the iterator so null arguments surface at call time.
            // The order of the checks decides which argument is reported when several are null.
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (keySelector is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector);
            }

            if (seedSelector is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.seedSelector);
            }

            if (func is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.func);
            }

            // An empty array cannot yield any group, so skip creating the iterator altogether.
            return IsEmptyArray(source)
                ? []
                : AggregateByIterator(source, keySelector, seedSelector, func, keyComparer);
        }

        /// <summary>
        /// Deferred body of the fixed-seed <c>AggregateBy</c> overload: consumes the whole source into a
        /// per-key dictionary, then yields its entries.
        /// </summary>
        private static IEnumerable<KeyValuePair<TKey, TAccumulate>> AggregateByIterator<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? keyComparer) where TKey : notnull
        {
            // The source enumerator lives for the whole iterator, including while the results are yielded,
            // and is disposed when the iterator finishes or is disposed.
            using IEnumerator<TSource> e = source.GetEnumerator();

            // No elements means no groups; avoid allocating a dictionary.
            if (!e.MoveNext())
            {
                yield break;
            }

            Dictionary<TKey, TAccumulate> groups = Accumulate(e, keySelector, seed, func, keyComparer);
            foreach (KeyValuePair<TKey, TAccumulate> pair in groups)
            {
                yield return pair;
            }

            // Folds the elements of 'it' (already positioned on its first element) into a dictionary keyed by
            // keySelector. Lives outside the iterator body, presumably because iterators could not declare
            // ref locals (such as 'slot' below) before C# 13.
            static Dictionary<TKey, TAccumulate> Accumulate(IEnumerator<TSource> it, Func<TSource, TKey> keySelector, TAccumulate seed, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? keyComparer)
            {
                var map = new Dictionary<TKey, TAccumulate>(keyComparer);

                for (bool hasCurrent = true; hasCurrent; hasCurrent = it.MoveNext())
                {
                    TSource item = it.Current;
                    TKey key = keySelector(item);

                    // One hash lookup both finds an existing slot and inserts a default one for a new key;
                    // the slot is then overwritten in place with the updated accumulator.
                    ref TAccumulate? slot = ref CollectionsMarshal.GetValueRefOrAddDefault(map, key, out bool found);
                    TAccumulate running = found ? slot! : seed;
                    slot = func(running, item);
                }

                return map;
            }
        }

        /// <summary>
        /// Deferred body of the seed-selector <c>AggregateBy</c> overload: consumes the whole source into a
        /// per-key dictionary, then yields its entries.
        /// </summary>
        private static IEnumerable<KeyValuePair<TKey, TAccumulate>> AggregateByIterator<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? keyComparer) where TKey : notnull
        {
            // The source enumerator lives for the whole iterator, including while the results are yielded,
            // and is disposed when the iterator finishes or is disposed.
            using IEnumerator<TSource> e = source.GetEnumerator();

            // No elements means no groups; avoid allocating a dictionary.
            if (!e.MoveNext())
            {
                yield break;
            }

            Dictionary<TKey, TAccumulate> groups = Accumulate(e, keySelector, seedSelector, func, keyComparer);
            foreach (KeyValuePair<TKey, TAccumulate> pair in groups)
            {
                yield return pair;
            }

            // Folds the elements of 'it' (already positioned on its first element) into a dictionary keyed by
            // keySelector. Lives outside the iterator body, presumably because iterators could not declare
            // ref locals (such as 'slot' below) before C# 13.
            static Dictionary<TKey, TAccumulate> Accumulate(IEnumerator<TSource> it, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? keyComparer)
            {
                var map = new Dictionary<TKey, TAccumulate>(keyComparer);

                for (bool hasCurrent = true; hasCurrent; hasCurrent = it.MoveNext())
                {
                    TSource item = it.Current;
                    TKey key = keySelector(item);

                    // One hash lookup both finds an existing slot and inserts a default one for a new key.
                    // seedSelector runs only for a key not seen before, after its slot has been inserted.
                    ref TAccumulate? slot = ref CollectionsMarshal.GetValueRefOrAddDefault(map, key, out bool found);
                    TAccumulate running = found ? slot! : seedSelector(key);
                    slot = func(running, item);
                }

                return map;
            }
        }
    }
}