File: Algorithms\SweeperProbabilityUtils.cs
Web Access
Project: src\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj (Microsoft.ML.Sweeper)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Internal.CpuMath;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Sweeper.Algorithms
{
    public sealed class SweeperProbabilityUtils
    {
        private readonly Random _rng;
 
        public SweeperProbabilityUtils(IHost host)
        {
            _rng = new Random(host.Rand.Next());
        }
 
        public static double Sum(double[] a)
        {
            double total = 0;
            foreach (double d in a)
                total += d;
            return total;
        }
 
        public static double NormalPdf(double x, double mean, double variance)
        {
            const double minVariance = 1e-200;
            variance = Math.Max(variance, minVariance);
            return 1 / Math.Sqrt(2 * Math.PI * variance) * Math.Exp(-Math.Pow(x - mean, 2) / (2 * variance));
        }
 
        public static double NormalCdf(double x, double mean, double variance)
        {
            double centered = x - mean;
            double ztrans = centered / (Math.Sqrt(variance) * Math.Sqrt(2));
 
            return 0.5 * (1 + ProbabilityFunctions.Erf(ztrans));
        }
 
        public static double StdNormalPdf(double x)
        {
            return 1 / Math.Sqrt(2 * Math.PI) * Math.Exp(-Math.Pow(x, 2) / 2);
        }
 
        public static double StdNormalCdf(double x)
        {
            return 0.5 * (1 + ProbabilityFunctions.Erf(x * 1 / Math.Sqrt(2)));
        }
 
        /// <summary>
        /// Samples from a Gaussian Normal with mean mu and std dev sigma.
        /// </summary>
        /// <param name="numRVs">Number of samples</param>
        /// <param name="mu">mean</param>
        /// <param name="sigma">standard deviation</param>
        /// <returns></returns>
        public double[] NormalRVs(int numRVs, double mu, double sigma)
        {
            List<double> rvs = new List<double>();
            double u1;
            double u2;
 
            for (int i = 0; i < numRVs; i++)
            {
                u1 = _rng.NextDouble();
                u2 = _rng.NextDouble();
                rvs.Add(mu + sigma * Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2));
            }
 
            return rvs.ToArray();
        }
 
        /// <summary>
        /// This performs (slow) roulette-wheel sampling of a categorical distribution. Should be swapped for other
        /// method as soon as one is available.
        /// </summary>
        /// <param name="numSamples">Number of samples to draw.</param>
        /// <param name="weights">Weights for distribution (should sum to 1).</param>
        /// <returns>A set of indicies indicating which element was chosen for each sample.</returns>
        public int[] SampleCategoricalDistribution(int numSamples, double[] weights)
        {
            // Normalize weights if necessary.
            double total = Sum(weights);
            if (Math.Abs(1.0 - total) > 0.0001)
                weights = Normalize(weights);
 
            // Build roulette wheel.
            double[] rw = new double[weights.Length];
            double cs = 0.0;
            for (int i = 0; i < weights.Length; i++)
            {
                cs += weights[i];
                rw[i] = cs;
            }
 
            // Draw samples.
            int[] results = new int[numSamples];
            for (int i = 0; i < results.Length; i++)
            {
                double u = _rng.NextDouble();
                results[i] = BinarySearch(rw, u, 0, rw.Length - 1);
            }
 
            return results;
        }
 
        public double SampleUniform()
        {
            return _rng.NextDouble();
        }
 
        /// <summary>
        /// Simple binary search method for finding smallest index in array where value
        /// meets or exceeds what you're looking for.
        /// </summary>
        /// <param name="a">Array to search</param>
        /// <param name="u">Value to search for</param>
        /// <param name="low">Left boundary of search</param>
        /// <param name="high">Right boundary of search</param>
        /// <returns></returns>
        private int BinarySearch(double[] a, double u, int low, int high)
        {
            int diff = high - low;
            if (diff < 2)
                return a[low] >= u ? low : high;
            int mid = low + (diff / 2);
            return a[mid] >= u ? BinarySearch(a, u, low, mid) : BinarySearch(a, u, mid, high);
        }
 
        public static double[] Normalize(double[] weights)
        {
            double total = Sum(weights);
 
            // If all weights equal zero, set to 1 (to avoid divide by zero).
            if (total <= Double.Epsilon)
            {
                weights.SetValue(1, 0, weights.Length);
                total = weights.Length;
            }
 
            for (int i = 0; i < weights.Length; i++)
                weights[i] /= total;
            return weights;
        }
 
        public static double[] InverseNormalize(double[] weights)
        {
            weights = Normalize(weights);
 
            for (int i = 0; i < weights.Length; i++)
                weights[i] = 1 - weights[i];
 
            return Normalize(weights);
        }
 
        public static float[] ParameterSetAsFloatArray(IHost host, IValueGenerator[] sweepParams, ParameterSet ps, bool expandCategoricals = true)
        {
            host.Assert(ps.Count == sweepParams.Length);
 
            var result = new List<float>();
 
            for (int i = 0; i < sweepParams.Length; i++)
            {
                // This allows us to query possible values of this parameter.
                var sweepParam = sweepParams[i];
 
                // This holds the actual value for this parameter, chosen in this parameter set.
                var pset = ps[sweepParam.Name];
                host.AssertValue(pset);
 
                var parameterDiscrete = sweepParam as DiscreteValueGenerator;
                if (parameterDiscrete != null)
                {
                    int hotIndex = -1;
                    for (int j = 0; j < parameterDiscrete.Count; j++)
                    {
                        if (parameterDiscrete[j].Equals(pset))
                        {
                            hotIndex = j;
                            break;
                        }
                    }
                    host.Assert(hotIndex >= 0);
 
                    if (expandCategoricals)
                        for (int j = 0; j < parameterDiscrete.Count; j++)
                            result.Add(j == hotIndex ? 1 : 0);
                    else
                        result.Add(hotIndex);
                }
                else if (sweepParam is LongValueGenerator lvg)
                {
                    // Normalizing all numeric parameters to [0,1] range.
                    result.Add(lvg.NormalizeValue(new LongParameterValue(pset.Name, long.Parse(pset.ValueText))));
                }
                else if (sweepParam is FloatValueGenerator fvg)
                {
                    // Normalizing all numeric parameters to [0,1] range.
                    result.Add(fvg.NormalizeValue(new FloatParameterValue(pset.Name, float.Parse(pset.ValueText))));
                }
                else
                    throw host.Except("Smart sweeper can only sweep over discrete and numeric parameters");
            }
 
            return result.ToArray();
        }
 
        public static ParameterSet FloatArrayAsParameterSet(IHost host, IValueGenerator[] sweepParams, float[] array, bool expandedCategoricals = true)
        {
            Contracts.Assert(array.Length == sweepParams.Length);
 
            List<IParameterValue> parameters = new List<IParameterValue>();
            int currentArrayIndex = 0;
            for (int i = 0; i < sweepParams.Length; i++)
            {
                var parameterDiscrete = sweepParams[i] as DiscreteValueGenerator;
                if (parameterDiscrete != null)
                {
                    if (expandedCategoricals)
                    {
                        int hotIndex = -1;
                        for (int j = 0; j < parameterDiscrete.Count; j++)
                        {
                            if (array[i + j] > 0)
                            {
                                hotIndex = j;
                                break;
                            }
                        }
                        host.Assert(hotIndex >= i);
                        parameters.Add(new StringParameterValue(sweepParams[i].Name, parameterDiscrete[hotIndex].ValueText));
                        currentArrayIndex += parameterDiscrete.Count;
                    }
                    else
                    {
                        parameters.Add(new StringParameterValue(sweepParams[i].Name, parameterDiscrete[(int)array[currentArrayIndex]].ValueText));
                        currentArrayIndex++;
                    }
                }
                else
                {
                    parameters.Add(sweepParams[i].CreateFromNormalized(array[currentArrayIndex]));
                    currentArrayIndex++;
                }
            }
 
            return new ParameterSet(parameters);
        }
    }
}