File: STL\MathsTools.cs
Web Access
Project: src\src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj (Microsoft.ML.TimeSeries)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
 
namespace Microsoft.ML.TimeSeries
{
    /// <summary>
    /// Order-statistic helpers: median and k-th smallest element selection over lists of doubles.
    /// </summary>
    internal class MathUtility
    {
        /// <summary>
        /// Returns the median of <paramref name="values"/> using a quick-select (partition based) search
        /// rather than a full sort.
        /// </summary>
        /// <remarks>
        /// For an even number of elements the larger of the two middle elements is returned; the two
        /// middle elements are not averaged. The input list is only read and is never modified.
        /// </remarks>
        /// <param name="values">The input list of values.</param>
        /// <returns>
        /// The median value, or <see cref="double.NaN"/> when <paramref name="values"/> is null or empty.
        /// </returns>
        public static double QuickMedian(List<double> values)
        {
            if (values == null || values.Count == 0)
                return double.NaN;

            // QuickSelect takes a 1-based rank, so the element at sorted 0-based position Count / 2
            // has rank Count / 2 + 1.
            return QuickSelect(values, values.Count / 2 + 1);
        }

        /// <summary>
        /// Returns the k-th smallest value (1-based) of <paramref name="values"/> by repeatedly
        /// partitioning the candidates around a median-of-three pivot.
        /// </summary>
        /// <param name="values">The list of values. It is only read and is never modified.</param>
        /// <param name="k">
        /// The 1-based rank of the value to return: 1 selects the minimum and <c>values.Count</c> selects the maximum.
        /// </param>
        /// <returns>The k-th smallest value in <paramref name="values"/>.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="values"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="k"/> is less than 1 or greater than the number of values
        /// (this includes every <paramref name="k"/> when the list is empty).
        /// </exception>
        public static double QuickSelect(IReadOnlyList<double> values, int k)
        {
            if (values == null)
                throw new System.ArgumentNullException(nameof(values));

            // Without this check an out-of-range rank either returns an unrelated element or fails
            // later with an index error from inside the partition loop.
            if (k < 1 || k > values.Count)
                throw new System.ArgumentOutOfRangeException(nameof(k), k, "k must be between 1 and the number of values.");

            // The first pass reads from the caller's list; later passes read from one of the two
            // scratch buffers below, so the caller's list is never written to.
            var nums = values;
            double[] left = new double[values.Count];
            double[] right = new double[values.Count];
            int numsCount = nums.Count;

            while (true)
            {
                if (numsCount == 1)
                    return nums[0];

                int idx = FindMedianIndex(nums, 0, numsCount - 1);
                double key = nums[idx];

                // Partition every element except the pivot: strictly smaller values go to 'left',
                // everything else (including duplicates of the pivot) goes to 'right'.
                // 'nums' may be the same array as 'left' or 'right' here. That is safe because the
                // write cursor (leftIdx / rightIdx) never gets ahead of the read cursor i, so a slot
                // is only overwritten after it has been read.
                int leftIdx = 0;
                int rightIdx = 0;
                for (int i = 0; i < numsCount; i++)
                {
                    if (i == idx)
                        continue;

                    if (nums[i] < key)
                        left[leftIdx++] = nums[i];
                    else
                        right[rightIdx++] = nums[i];
                }

                // Exactly k - 1 values are smaller than the pivot, so the pivot has rank k.
                if (leftIdx == k - 1)
                    return key;

                if (leftIdx >= k)
                {
                    // The answer is among the smaller values; its rank there is unchanged.
                    nums = left;
                    numsCount = leftIdx;
                }
                else
                {
                    // The answer is among the remaining values; discount the smaller values and the pivot.
                    nums = right;
                    k = k - leftIdx - 1;
                    numsCount = rightIdx;
                }
            }
        }

        /// <summary>
        /// Picks a pivot position by the median-of-three rule: of the elements at <paramref name="start"/>,
        /// <paramref name="end"/> and the midpoint <c>(start + end) / 2</c>, returns the index of the one
        /// whose value lies between the other two.
        /// </summary>
        /// <param name="values">The list to inspect. It is only read.</param>
        /// <param name="start">Index of the first element of the range.</param>
        /// <param name="end">Index of the last element of the range (inclusive).</param>
        /// <returns>One of <paramref name="start"/>, <paramref name="end"/> or the midpoint index.</returns>
        public static int FindMedianIndex(IReadOnlyList<double> values, int start, int end)
        {
            // Using the middle value among first/middle/last as the pivot keeps the average performance
            // good and behaves well when the input is already ordered (the original authors measured
            // roughly a 10% average improvement in their unit tests).
            double first = values[start];
            double last = values[end];
            int midIndex = (start + end) / 2;
            double middleValue = values[midIndex];

            if (first < last)
            {
                // first < last < middleValue: last is the middle value.
                if (middleValue > last)
                    return end;

                // first < middleValue <= last: middleValue is the middle value.
                if (middleValue > first)
                    return midIndex;

                // middleValue <= first < last: first is the middle value.
                return start;
            }

            // From here on first is not smaller than last.

            // last <= first < middleValue: first is the middle value.
            if (middleValue > first)
                return start;

            // middleValue < last <= first: last is the middle value.
            if (middleValue < last)
                return end;

            // Otherwise middleValue lies between last and first.
            return midIndex;
        }
    }
}