File: System\Linq\Average.cs
Web Access
Project: src\src\libraries\System.Linq\src\System.Linq.csproj (System.Linq)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
 
using System.Collections.Generic;
using System.Numerics;
 
namespace System.Linq
{
    public static partial class Enumerable
    {
        // Average of a non-nullable Int32 sequence.
        //
        // This overload does not go through the generic helper because its span case can be
        // vectorized. Every element is widened to Int64. A span holds at most Int32.MaxValue
        // elements, each with magnitude at most 2^31, so the Int64 total cannot overflow.
        // That also means the additions may be performed in any order.
        public static double Average(this IEnumerable<int> source)
        {
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (!source.TryGetSpan(out ReadOnlySpan<int> values))
            {
                // No contiguous view of the data: walk the enumerator. The Int64 total is
                // updated with checked arithmetic because the element count here is not
                // bounded by a span length.
                using IEnumerator<int> e = source.GetEnumerator();

                if (!e.MoveNext())
                {
                    ThrowHelper.ThrowNoElementsException();
                }

                long runningTotal = e.Current;
                long seen = 1;
                while (e.MoveNext())
                {
                    runningTotal = checked(runningTotal + e.Current);
                    seen++;
                }

                return (double)runningTotal / seen;
            }

            if (values.Length == 0)
            {
                ThrowHelper.ThrowNoElementsException();
            }

            long total = 0;
            int index = 0;
            int laneCount = Vector<int>.Count;

            if (Vector.IsHardwareAccelerated && values.Length >= laneCount)
            {
                // Consume one full Vector<int> per iteration. Each chunk is widened into two
                // Vector<long> halves so that the per-lane partial sums are 64-bit.
                Vector<long> partialSums = Vector<long>.Zero;
                int lastFullChunkStart = values.Length - laneCount;

                while (index <= lastFullChunkStart)
                {
                    Vector.Widen(new Vector<int>(values.Slice(index)), out Vector<long> lower, out Vector<long> upper);
                    partialSums += lower + upper;
                    index += laneCount;
                }

                total = Vector.Sum(partialSums);
            }

            // Scalar tail: whatever did not fill a whole vector, or every element when SIMD is unavailable.
            foreach (int value in values.Slice(index))
            {
                total += value;
            }

            return (double)total / values.Length;
        }

        // Remaining non-nullable overloads. Each forwards to the generic helper with the type
        // arguments <element, accumulator, result>. Int64 values accumulate in Int64, Single
        // values in Double (narrowed back to Single at the end), and Double/Decimal in themselves.
        public static double Average(this IEnumerable<long> source) => Average<long, long, double>(source);

        public static float Average(this IEnumerable<float> source)
        {
            double result = Average<float, double, double>(source);
            return (float)result;
        }

        public static double Average(this IEnumerable<double> source) => Average<double, double, double>(source);

        public static decimal Average(this IEnumerable<decimal> source) => Average<decimal, decimal, decimal>(source);

        // Shared implementation for the non-nullable, selector-less overloads.
        //   TSource      - element type of the sequence.
        //   TAccumulator - type the running total is kept in (added with checked arithmetic).
        //   TResult      - type the total and the element count are converted to before dividing.
        // Empty input is rejected through ThrowHelper.ThrowNoElementsException.
        private static TResult Average<TSource, TAccumulator, TResult>(this IEnumerable<TSource> source)
            where TSource : struct, INumber<TSource>
            where TAccumulator : struct, INumber<TAccumulator>
            where TResult : struct, INumber<TResult>
        {
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (source.TryGetSpan(out ReadOnlySpan<TSource> items))
            {
                if (items.Length == 0)
                {
                    ThrowHelper.ThrowNoElementsException();
                }

                // Contiguous data: the span-based Sum helper performs the accumulation.
                TResult spanSum = TResult.CreateChecked(Sum<TSource, TAccumulator>(items));
                TResult spanCount = TResult.CreateChecked(items.Length);
                return spanSum / spanCount;
            }

            using IEnumerator<TSource> e = source.GetEnumerator();

            if (!e.MoveNext())
            {
                ThrowHelper.ThrowNoElementsException();
            }

            TAccumulator runningSum = TAccumulator.CreateChecked(e.Current);
            long count = 1;
            while (e.MoveNext())
            {
                runningSum = checked(runningSum + TAccumulator.CreateChecked(e.Current));
                count++;
            }

            return TResult.CreateChecked(runningSum) / TResult.CreateChecked(count);
        }


        // Nullable-element overloads. Null entries are ignored. A sequence with no non-null
        // entries (including an empty one) averages to null instead of throwing.
        public static double? Average(this IEnumerable<int?> source) => Average<int, long, double>(source);

        public static double? Average(this IEnumerable<long?> source) => Average<long, long, double>(source);

        public static float? Average(this IEnumerable<float?> source) => (float?)Average<float, double, double>(source);

        public static double? Average(this IEnumerable<double?> source) => Average<double, double, double>(source);

        public static decimal? Average(this IEnumerable<decimal?> source) => Average<decimal, decimal, decimal>(source);

        // Shared implementation for the nullable-element overloads.
        // Returns null when no element has a value; otherwise returns the checked total of
        // the non-null elements divided by their count.
        private static TResult? Average<TSource, TAccumulator, TResult>(this IEnumerable<TSource?> source)
            where TSource : struct, INumber<TSource>
            where TAccumulator : struct, INumber<TAccumulator>
            where TResult : struct, INumber<TResult>
        {
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            using IEnumerator<TSource?> e = source.GetEnumerator();

            // Skip leading nulls. Running out of elements here means there was nothing to average.
            TSource? current;
            do
            {
                if (!e.MoveNext())
                {
                    return null;
                }

                current = e.Current;
            }
            while (!current.HasValue);

            TAccumulator runningSum = TAccumulator.CreateChecked(current.GetValueOrDefault());
            long count = 1;

            while (e.MoveNext())
            {
                current = e.Current;
                if (current.HasValue)
                {
                    runningSum = checked(runningSum + TAccumulator.CreateChecked(current.GetValueOrDefault()));
                    count++;
                }
            }

            return TResult.CreateChecked(runningSum) / TResult.CreateChecked(count);
        }


        // Selector overloads with non-nullable projections. They use the same
        // accumulator/result type choices as the plain element overloads above.
        public static double Average<TSource>(this IEnumerable<TSource> source, Func<TSource, int> selector) => Average<TSource, int, long, double>(source, selector);

        public static double Average<TSource>(this IEnumerable<TSource> source, Func<TSource, long> selector) => Average<TSource, long, long, double>(source, selector);

        public static float Average<TSource>(this IEnumerable<TSource> source, Func<TSource, float> selector)
        {
            double result = Average<TSource, float, double, double>(source, selector);
            return (float)result;
        }

        public static double Average<TSource>(this IEnumerable<TSource> source, Func<TSource, double> selector) => Average<TSource, double, double, double>(source, selector);

        public static decimal Average<TSource>(this IEnumerable<TSource> source, Func<TSource, decimal> selector) => Average<TSource, decimal, decimal, decimal>(source, selector);

        // Shared implementation for the non-nullable selector overloads. The selector is
        // invoked exactly once per element, in enumeration order.
        // Empty input is rejected through ThrowHelper.ThrowNoElementsException.
        private static TResult Average<TSource, TSelector, TAccumulator, TResult>(this IEnumerable<TSource> source, Func<TSource, TSelector> selector)
            where TSelector : struct, INumber<TSelector>
            where TAccumulator : struct, INumber<TAccumulator>
            where TResult : struct, INumber<TResult>
        {
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (selector is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.selector);
            }

            using IEnumerator<TSource> e = source.GetEnumerator();

            if (!e.MoveNext())
            {
                ThrowHelper.ThrowNoElementsException();
            }

            TAccumulator runningSum = TAccumulator.CreateChecked(selector(e.Current));
            long count = 1;
            while (e.MoveNext())
            {
                runningSum = checked(runningSum + TAccumulator.CreateChecked(selector(e.Current)));
                count++;
            }

            return TResult.CreateChecked(runningSum) / TResult.CreateChecked(count);
        }


        // Selector overloads with nullable projections. Projected nulls are ignored, and
        // the result is null when every projection is null (or the source is empty).
        public static double? Average<TSource>(this IEnumerable<TSource> source, Func<TSource, int?> selector) => Average<TSource, int, long, double>(source, selector);

        public static double? Average<TSource>(this IEnumerable<TSource> source, Func<TSource, long?> selector) => Average<TSource, long, long, double>(source, selector);

        public static float? Average<TSource>(this IEnumerable<TSource> source, Func<TSource, float?> selector) => (float?)Average<TSource, float, double, double>(source, selector);

        public static double? Average<TSource>(this IEnumerable<TSource> source, Func<TSource, double?> selector) => Average<TSource, double, double, double>(source, selector);

        public static decimal? Average<TSource>(this IEnumerable<TSource> source, Func<TSource, decimal?> selector) => Average<TSource, decimal, decimal, decimal>(source, selector);

        // Shared implementation for the nullable selector overloads. The selector is invoked
        // exactly once per element. Returns null when no projection has a value.
        private static TResult? Average<TSource, TSelector, TAccumulator, TResult>(this IEnumerable<TSource> source, Func<TSource, TSelector?> selector)
            where TSelector : struct, INumber<TSelector>
            where TAccumulator : struct, INumber<TAccumulator>
            where TResult : struct, INumber<TResult>
        {
            if (source is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
            }

            if (selector is null)
            {
                ThrowHelper.ThrowArgumentNullException(ExceptionArgument.selector);
            }

            using IEnumerator<TSource> e = source.GetEnumerator();

            // Skip elements whose projection is null. Exhausting the source here yields null.
            TSelector? projected;
            do
            {
                if (!e.MoveNext())
                {
                    return null;
                }

                projected = selector(e.Current);
            }
            while (!projected.HasValue);

            TAccumulator runningSum = TAccumulator.CreateChecked(projected.GetValueOrDefault());
            long count = 1;

            while (e.MoveNext())
            {
                projected = selector(e.Current);
                if (projected.HasValue)
                {
                    runningSum = checked(runningSum + TAccumulator.CreateChecked(projected.GetValueOrDefault()));
                    count++;
                }
            }

            return TResult.CreateChecked(runningSum) / TResult.CreateChecked(count);
        }
    }
}