|  | 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Sweeper;
 
// Loadable-class declarations for the three sweep parameter generators defined in this file.
// Each entry pairs a generator type with its options type and the SignatureSweeperParameter signature.
// The last string of each entry ("lp", "fp", "dp") is the same tag the matching generator writes in
// its ToStringParameter output (" p=lp{...}", " p=fp{...}", " p=dp{...}").
// NOTE(review): presumably the component loader resolves these short names when a parameter is
// given as text - confirm against the LoadableClass attribute documentation.
[assembly: LoadableClass(typeof(LongValueGenerator), typeof(LongParamOptions), typeof(SignatureSweeperParameter),
    "Long parameter", "lp")]
[assembly: LoadableClass(typeof(FloatValueGenerator), typeof(FloatParamOptions), typeof(SignatureSweeperParameter),
    "Float parameter", "fp")]
[assembly: LoadableClass(typeof(DiscreteValueGenerator), typeof(DiscreteParamOptions), typeof(SignatureSweeperParameter),
    "Discrete parameter", "dp")]
 
namespace Microsoft.ML.Sweeper
{
    /// <summary>
    /// Signature type named in the <c>LoadableClass</c> declarations of the sweep parameter generators
    /// in this file. Nothing in this file invokes the delegate; it is only referenced via <c>typeof</c>.
    /// </summary>
    public delegate void SignatureSweeperParameter();
 
    /// <summary>
    /// Common base of the option classes for sweep parameters; it carries only the parameter name.
    /// </summary>
    public abstract class BaseParamOptions
    {
        // Name of the swept parameter (required argument, short name "n"). The generators in this
        // file expose it through their Name property and pass it to every parameter value they create.
        [Argument(ArgumentType.Required, HelpText = "Parameter name", ShortName = "n")]
        public string Name;
    }
 
    /// <summary>
    /// Options shared by the numeric (long and float) sweep parameters: how many grid steps to take,
    /// an optional explicit step, and whether the scale is logarithmic.
    /// </summary>
    public abstract class NumericParamOptions : BaseParamOptions
    {
        // Number of grid points (default 100). When StepSize is null the generators derive the step
        // from the range split into (NumSteps - 1) intervals.
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Number of steps for grid runthrough.", ShortName = "steps")]
        public int NumSteps = 100;

        // Optional explicit step; null means "derive it from NumSteps". The generators add it on a
        // linear scale and multiply by it on a log scale, and their constructors require it to be
        // greater than 0 (linear) or greater than 1 (log).
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Amount of increment between steps (multiplicative if log).", ShortName = "inc")]
        public Double? StepSize = null;

        // Despite the name this is a flag, not a base: true selects the log scale.
        [Argument(ArgumentType.LastOccurrenceWins, HelpText = "Log scale.", ShortName = "log")]
        public bool LogBase = false;
    }
 
    /// <summary>
    /// Options for a floating point sweep parameter: the numeric options plus the range bounds.
    /// </summary>
    public class FloatParamOptions : NumericParamOptions
    {
        // Lower bound of the range (required). FloatValueGenerator requires Min < Max, and Min > 0
        // when the log scale is used.
        [Argument(ArgumentType.Required, HelpText = "Minimum value")]
        public float Min;

        // Upper bound of the range (required); FloatValueGenerator.InRange treats it as inclusive.
        [Argument(ArgumentType.Required, HelpText = "Maximum value")]
        public float Max;
    }
 
    /// <summary>
    /// Options for an integer sweep parameter: the numeric options plus the range bounds.
    /// </summary>
    public class LongParamOptions : NumericParamOptions
    {
        // Lower bound of the range (required). LongValueGenerator requires Min < Max, and Min > 0
        // when the log scale is used.
        [Argument(ArgumentType.Required, HelpText = "Minimum value")]
        public long Min;

        // Upper bound of the range (required); LongValueGenerator.InRange treats it as inclusive.
        [Argument(ArgumentType.Required, HelpText = "Maximum value")]
        public long Max;
    }
 
    /// <summary>
    /// Options for a discrete sweep parameter: the parameter name plus the list of candidate values.
    /// </summary>
    public class DiscreteParamOptions : BaseParamOptions
    {
        // Candidate values as text (short name "v", may be given multiple times). Defaults to null;
        // DiscreteValueGenerator reads Values.Length in its constructor and expects at least one entry.
        [Argument(ArgumentType.Multiple, HelpText = "Values", ShortName = "v")]
        public string[] Values = null;
    }
 
    /// <summary>
    /// An immutable, named <see cref="long"/> value of a sweep parameter.
    /// Two instances are equal when both their names and their values are equal.
    /// </summary>
    public sealed class LongParameterValue : IParameterValue<long>
    {
        private readonly string _name;
        private readonly string _valueText;
        private readonly long _value;

        /// <summary>Name of the parameter this value belongs to.</summary>
        public string Name
        {
            get { return _name; }
        }

        /// <summary>The value formatted with the "D" format and the invariant culture.</summary>
        public string ValueText
        {
            get { return _valueText; }
        }

        /// <summary>The numeric value.</summary>
        public long Value
        {
            get { return _value; }
        }

        /// <summary>
        /// Creates the value and caches its text form.
        /// </summary>
        /// <param name="name">Parameter name.</param>
        /// <param name="value">Parameter value.</param>
        public LongParameterValue(string name, long value)
        {
            _name = name;
            _value = value;
            // Format with the invariant culture: with the current culture the negative sign (and so
            // the text handed to other components) could differ from machine to machine.
            _valueText = _value.ToString("D", CultureInfo.InvariantCulture);
        }

        /// <summary>Same comparison as <see cref="Equals(object)"/>.</summary>
        public bool Equals(IParameterValue other)
        {
            return Equals((object)other);
        }

        /// <summary>
        /// True only for another <see cref="LongParameterValue"/> with the same name and value.
        /// </summary>
        public override bool Equals(object obj)
        {
            var lpv = obj as LongParameterValue;
            return lpv != null && Name == lpv.Name && _value == lpv._value;
        }

        /// <summary>Hash combining the type, the name and the value, consistent with Equals.</summary>
        public override int GetHashCode()
        {
            return Hashing.CombinedHash(0, typeof(LongParameterValue), _name, _value);
        }
    }
 
    /// <summary>
    /// An immutable, named <see cref="float"/> value of a sweep parameter. NaN is rejected.
    /// Two instances are equal when both their names and their values are equal.
    /// </summary>
    public sealed class FloatParameterValue : IParameterValue<float>
    {
        private readonly string _name;
        private readonly string _valueText;
        private readonly float _value;

        /// <summary>Name of the parameter this value belongs to.</summary>
        public string Name
        {
            get { return _name; }
        }

        /// <summary>The value formatted with the round-trip ("R") format and the invariant culture.</summary>
        public string ValueText
        {
            get { return _valueText; }
        }

        /// <summary>The numeric value.</summary>
        public float Value
        {
            get { return _value; }
        }

        /// <summary>
        /// Creates the value and caches its text form.
        /// </summary>
        /// <param name="name">Parameter name.</param>
        /// <param name="value">Parameter value; the contract check fails if it is NaN.</param>
        public FloatParameterValue(string name, float value)
        {
            Contracts.Check(!float.IsNaN(value));
            _name = name;
            _value = value;
            // Format with the invariant culture: the current culture may use ',' as the decimal
            // separator (e.g. "0,5"), which is not a machine-readable number.
            _valueText = _value.ToString("R", CultureInfo.InvariantCulture);
        }

        /// <summary>Same comparison as <see cref="Equals(object)"/>.</summary>
        public bool Equals(IParameterValue other)
        {
            return Equals((object)other);
        }

        /// <summary>
        /// True only for another <see cref="FloatParameterValue"/> with the same name and exactly the same value.
        /// </summary>
        public override bool Equals(object obj)
        {
            var fpv = obj as FloatParameterValue;
            return fpv != null && Name == fpv.Name && _value == fpv._value;
        }

        /// <summary>Hash combining the type, the name and the value, consistent with Equals.</summary>
        public override int GetHashCode()
        {
            return Hashing.CombinedHash(0, typeof(FloatParameterValue), _name, _value);
        }
    }
 
    /// <summary>
    /// An immutable, named textual value of a sweep parameter. The value and its text form are
    /// the same string; equality is by name and text.
    /// </summary>
    public sealed class StringParameterValue : IParameterValue<string>
    {
        private readonly string _name;
        private readonly string _value;

        /// <summary>Name of the parameter this value belongs to.</summary>
        public string Name => _name;

        /// <summary>Text form of the value; identical to <see cref="Value"/>.</summary>
        public string ValueText => _value;

        /// <summary>The value itself.</summary>
        public string Value => _value;

        /// <summary>Stores the name and the value as given.</summary>
        public StringParameterValue(string name, string value)
        {
            _name = name;
            _value = value;
        }

        /// <summary>Same comparison as <see cref="Equals(object)"/>.</summary>
        public bool Equals(IParameterValue other) => Equals((object)other);

        /// <summary>
        /// True only for another <see cref="StringParameterValue"/> with the same name and text.
        /// </summary>
        public override bool Equals(object obj)
        {
            var that = obj as StringParameterValue;
            if (that == null)
                return false;

            return Name == that.Name && ValueText == that.ValueText;
        }

        /// <summary>Hash combining the type, the name and the value.</summary>
        public override int GetHashCode() => Hashing.CombinedHash(0, typeof(StringParameterValue), _name, _value);
    }
 
    /// <summary>
    /// A value generator over a numeric range that can additionally map a value back to a
    /// normalized position and test whether a value lies inside the range.
    /// </summary>
    public interface INumericValueGenerator : IValueGenerator
    {
        // Maps a parameter value to a normalized float. The implementations in this file return 0
        // for Min and 1 for Max, and fail a contract check for a value of the wrong type or outside
        // [Min, Max].
        float NormalizeValue(IParameterValue value);
        // Tells whether the value is inside the generator's range. The implementations in this file
        // treat both bounds as inclusive and fail a contract check for a value of the wrong type.
        bool InRange(IParameterValue value);
    }
 
    /// <summary>
    /// The integer type parameter sweep.
    /// </summary>
    /// <remarks>
    /// Produces <see cref="LongParameterValue"/> instances between Min and Max of the supplied
    /// <see cref="LongParamOptions"/>, either from a lazily built grid (indexer and <see cref="Count"/>)
    /// or from a normalized position (<see cref="CreateFromNormalized"/>), on a linear or log scale.
    /// </remarks>
    public class LongValueGenerator : INumericValueGenerator
    {
        private readonly LongParamOptions _options;
        // Grid of values; null until the indexer or Count is first used.
        private IParameterValue[] _gridValues;

        public string Name { get { return _options.Name; } }

        /// <summary>
        /// Validates the options and stores them. Fails a contract check when Min is not less than
        /// Max, when a log scale is requested with a non-positive Min or a step not greater than 1,
        /// or when a linear scale is requested with a step not greater than 0.
        /// </summary>
        public LongValueGenerator(LongParamOptions options)
        {
            Contracts.Check(options.Min < options.Max, "min must be less than max");
            // REVIEW: this condition can be relaxed if we change the math below to deal with it
            Contracts.Check(!options.LogBase || options.Min > 0, "min must be positive if log scale is used");
            Contracts.Check(!options.LogBase || options.StepSize == null || options.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
            Contracts.Check(options.LogBase || options.StepSize == null || options.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
            _options = options;
        }

        // The step can only be derived from NumSteps when there is at least one interval.
        // With NumSteps <= 1 the derived step is infinite or negative, which made the grid loops
        // run forever and the log base unusable, so fail explicitly instead.
        private void CheckNumStepsWhenStepIsDerived()
        {
            Contracts.Check(_options.StepSize != null || _options.NumSteps > 1,
                "NumSteps must be greater than 1 if StepSize is not specified");
        }

        // Multiplicative step of the log scale: the explicit StepSize, or the factor that covers
        // Max / Min in (NumSteps - 1) equal multiplicative steps.
        private Double GetLogBase()
        {
            if (_options.StepSize.HasValue)
                return _options.StepSize.Value;

            CheckNumStepsWhenStepIsDerived();
            return Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1));
        }

        /// <summary>
        /// Creates the value at a normalized position: 0 corresponds to Min and 1 to Max, linearly
        /// or on the log scale. The result is truncated to a long. The position is not range checked.
        /// </summary>
        // REVIEW: Is float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
            long val;
            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
                var logBase = GetLogBase();
                var logMax = Math.Log(_options.Max, logBase);
                var logMin = Math.Log(_options.Min, logBase);
                val = (long)(_options.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
            }
            else
                val = (long)(_options.Min + normalizedValue * (_options.Max - _options.Min));

            return new LongParameterValue(_options.Name, val);
        }

        // Builds the grid once. Consecutive steps that truncate to the same long are emitted once.
        private void EnsureParameterValues()
        {
            if (_gridValues != null)
                return;

            var result = new List<IParameterValue>();
            // Enumerate every integer when more steps are requested than the range spans, or when
            // an explicit step is at most 1 (only possible on the linear scale, see the constructor).
            if ((_options.StepSize == null && _options.NumSteps > (_options.Max - _options.Min)) ||
                (_options.StepSize != null && _options.StepSize <= 1))
            {
                for (long i = _options.Min; i <= _options.Max; i++)
                    result.Add(new LongParameterValue(_options.Name, i));
            }
            else
            {
                if (_options.LogBase)
                {
                    // REVIEW: review the math below, it only works for positive Min and Max
                    var logBase = GetLogBase();

                    long prevValue = long.MinValue;
                    // Upper limit is half a multiplicative step above Max, presumably so that
                    // accumulated rounding does not drop the last grid point.
                    var maxPlusEpsilon = _options.Max * Math.Sqrt(logBase);
                    for (Double value = _options.Min; value <= maxPlusEpsilon; value *= logBase)
                    {
                        var longValue = (long)value;
                        if (longValue > prevValue)
                            result.Add(new LongParameterValue(_options.Name, longValue));
                        prevValue = longValue;
                    }
                }
                else
                {
                    CheckNumStepsWhenStepIsDerived();
                    var stepSize = _options.StepSize ?? (Double)(_options.Max - _options.Min) / (_options.NumSteps - 1);
                    long prevValue = long.MinValue;
                    // Upper limit is half a step above Max, presumably so that accumulated rounding
                    // does not drop the last grid point.
                    var maxPlusEpsilon = _options.Max + stepSize / 2;
                    for (Double value = _options.Min; value <= maxPlusEpsilon; value += stepSize)
                    {
                        var longValue = (long)value;
                        if (longValue > prevValue)
                            result.Add(new LongParameterValue(_options.Name, longValue));
                        prevValue = longValue;
                    }
                }
            }
            _gridValues = result.ToArray();
        }

        /// <summary>The i-th grid value; builds the grid on first use.</summary>
        public IParameterValue this[int i]
        {
            get
            {
                EnsureParameterValues();
                return _gridValues[i];
            }
        }

        /// <summary>Number of grid values; builds the grid on first use.</summary>
        public int Count
        {
            get
            {
                EnsureParameterValues();
                return _gridValues.Length;
            }
        }

        /// <summary>
        /// Maps a value in [Min, Max] to [0, 1], linearly or on the log scale. Fails a contract
        /// check when the value is not a <see cref="LongParameterValue"/> or is out of range.
        /// </summary>
        public float NormalizeValue(IParameterValue value)
        {
            var valueTyped = value as LongParameterValue;
            Contracts.Check(valueTyped != null, "LongValueGenerator could not normalized parameter because it is not of the correct type");
            Contracts.Check(_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max, "Value not in correct range");

            if (_options.LogBase)
            {
                float logBase = (float)GetLogBase();
                return (float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_options.Min, logBase)) / (Math.Log(_options.Max, logBase) - Math.Log(_options.Min, logBase)));
            }
            else
                return (float)(valueTyped.Value - _options.Min) / (_options.Max - _options.Min);
        }

        /// <summary>
        /// True when the value lies in [Min, Max], both inclusive. Fails a contract check when the
        /// value is not a <see cref="LongParameterValue"/>.
        /// </summary>
        public bool InRange(IParameterValue value)
        {
            var valueTyped = value as LongParameterValue;
            Contracts.Check(valueTyped != null, "Parameter should be of type LongParameterValue");
            return (_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max);
        }

        /// <summary>
        /// Returns " p=lp{...}" wrapping what CmdParser.GetSettings produces for these options
        /// against a default <see cref="LongParamOptions"/>.
        /// </summary>
        public string ToStringParameter(IHostEnvironment env)
        {
            return $" p=lp{{{CmdParser.GetSettings(env, _options, new LongParamOptions())}}}";
        }
    }
 
    /// <summary>
    /// The floating point type parameter sweep.
    /// </summary>
    /// <remarks>
    /// Produces <see cref="FloatParameterValue"/> instances between Min and Max of the supplied
    /// <see cref="FloatParamOptions"/>, either from a lazily built grid (indexer and <see cref="Count"/>)
    /// or from a normalized position (<see cref="CreateFromNormalized"/>), on a linear or log scale.
    /// </remarks>
    public class FloatValueGenerator : INumericValueGenerator
    {
        private readonly FloatParamOptions _options;
        // Grid of values; null until the indexer or Count is first used.
        private IParameterValue[] _gridValues;

        public string Name { get { return _options.Name; } }

        /// <summary>
        /// Validates the options and stores them. Fails a contract check when Min is not less than
        /// Max, when a log scale is requested with a non-positive Min or a step not greater than 1,
        /// or when a linear scale is requested with a step not greater than 0.
        /// </summary>
        public FloatValueGenerator(FloatParamOptions options)
        {
            Contracts.Check(options.Min < options.Max, "min must be less than max");
            // REVIEW: this condition can be relaxed if we change the math below to deal with it
            Contracts.Check(!options.LogBase || options.Min > 0, "min must be positive if log scale is used");
            Contracts.Check(!options.LogBase || options.StepSize == null || options.StepSize > 1, "StepSize must be greater than 1 if log scale is used");
            Contracts.Check(options.LogBase || options.StepSize == null || options.StepSize > 0, "StepSize must be greater than 0 if linear scale is used");
            _options = options;
        }

        // The step can only be derived from NumSteps when there is at least one interval.
        // With NumSteps <= 1 the derived step is infinite or negative, which made the grid loops
        // run forever and the log base unusable, so fail explicitly instead.
        private void CheckNumStepsWhenStepIsDerived()
        {
            Contracts.Check(_options.StepSize != null || _options.NumSteps > 1,
                "NumSteps must be greater than 1 if StepSize is not specified");
        }

        // Multiplicative step of the log scale: the explicit StepSize, or the factor that covers
        // Max / Min in (NumSteps - 1) equal multiplicative steps.
        private Double GetLogBase()
        {
            if (_options.StepSize.HasValue)
                return _options.StepSize.Value;

            CheckNumStepsWhenStepIsDerived();
            return Math.Pow(1.0 * _options.Max / _options.Min, 1.0 / (_options.NumSteps - 1));
        }

        /// <summary>
        /// Creates the value at a normalized position: 0 corresponds to Min and 1 to Max, linearly
        /// or on the log scale. The position is not range checked.
        /// </summary>
        // REVIEW: Is float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
            float val;
            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
                var logBase = GetLogBase();
                var logMax = Math.Log(_options.Max, logBase);
                var logMin = Math.Log(_options.Min, logBase);
                val = (float)(_options.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin)));
            }
            else
                val = (float)(_options.Min + normalizedValue * (_options.Max - _options.Min));

            return new FloatParameterValue(_options.Name, val);
        }

        // Builds the grid once. Consecutive steps that round to the same float are emitted once.
        private void EnsureParameterValues()
        {
            if (_gridValues != null)
                return;

            var result = new List<IParameterValue>();
            if (_options.LogBase)
            {
                // REVIEW: review the math below, it only works for positive Min and Max
                var logBase = GetLogBase();

                float prevValue = float.NegativeInfinity;
                // Upper limit is half a multiplicative step above Max, presumably so that
                // accumulated rounding does not drop the last grid point.
                var maxPlusEpsilon = _options.Max * Math.Sqrt(logBase);
                for (Double value = _options.Min; value <= maxPlusEpsilon; value *= logBase)
                {
                    var floatValue = (float)value;
                    if (floatValue > prevValue)
                        result.Add(new FloatParameterValue(_options.Name, floatValue));
                    prevValue = floatValue;
                }
            }
            else
            {
                CheckNumStepsWhenStepIsDerived();
                var stepSize = _options.StepSize ?? (Double)(_options.Max - _options.Min) / (_options.NumSteps - 1);
                float prevValue = float.NegativeInfinity;
                // Upper limit is half a step above Max, presumably so that accumulated rounding
                // does not drop the last grid point.
                var maxPlusEpsilon = _options.Max + stepSize / 2;
                for (Double value = _options.Min; value <= maxPlusEpsilon; value += stepSize)
                {
                    var floatValue = (float)value;
                    if (floatValue > prevValue)
                        result.Add(new FloatParameterValue(_options.Name, floatValue));
                    prevValue = floatValue;
                }
            }

            _gridValues = result.ToArray();
        }

        /// <summary>The i-th grid value; builds the grid on first use.</summary>
        public IParameterValue this[int i]
        {
            get
            {
                EnsureParameterValues();
                return _gridValues[i];
            }
        }

        /// <summary>Number of grid values; builds the grid on first use.</summary>
        public int Count
        {
            get
            {
                EnsureParameterValues();
                return _gridValues.Length;
            }
        }

        /// <summary>
        /// Maps a value in [Min, Max] to [0, 1], linearly or on the log scale. Fails a contract
        /// check when the value is not a <see cref="FloatParameterValue"/> or is out of range.
        /// </summary>
        public float NormalizeValue(IParameterValue value)
        {
            var valueTyped = value as FloatParameterValue;
            Contracts.Check(valueTyped != null, "FloatValueGenerator could not normalized parameter because it is not of the correct type");
            Contracts.Check(_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max, "Value not in correct range");

            if (_options.LogBase)
            {
                float logBase = (float)GetLogBase();
                return (float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_options.Min, logBase)) / (Math.Log(_options.Max, logBase) - Math.Log(_options.Min, logBase)));
            }
            else
                return (valueTyped.Value - _options.Min) / (_options.Max - _options.Min);
        }

        /// <summary>
        /// True when the value lies in [Min, Max], both inclusive. Fails a contract check when the
        /// value is not a <see cref="FloatParameterValue"/>.
        /// </summary>
        public bool InRange(IParameterValue value)
        {
            var valueTyped = value as FloatParameterValue;
            Contracts.Check(valueTyped != null, "Parameter should be of type FloatParameterValue");
            return (_options.Min <= valueTyped.Value && valueTyped.Value <= _options.Max);
        }

        /// <summary>
        /// Returns " p=fp{...}" wrapping what CmdParser.GetSettings produces for these options
        /// against a default <see cref="FloatParamOptions"/>.
        /// </summary>
        public string ToStringParameter(IHostEnvironment env)
        {
            return $" p=fp{{{CmdParser.GetSettings(env, _options, new FloatParamOptions())}}}";
        }
    }
 
    /// <summary>
    /// The discrete parameter sweep.
    /// </summary>
    /// <remarks>
    /// Serves the textual values listed in <see cref="DiscreteParamOptions.Values"/> as
    /// <see cref="StringParameterValue"/> instances, by index or by normalized position.
    /// </remarks>
    public class DiscreteValueGenerator : IValueGenerator
    {
        private readonly DiscreteParamOptions _options;

        public string Name { get { return _options.Name; } }

        /// <summary>
        /// Stores the options. Fails a contract check when no values are supplied.
        /// </summary>
        public DiscreteValueGenerator(DiscreteParamOptions options)
        {
            // Values defaults to null in DiscreteParamOptions, so test for null before reading
            // Length; otherwise a missing list surfaces as a NullReferenceException.
            Contracts.Check(options.Values != null && options.Values.Length > 0, "Values must contain at least one element");
            _options = options;
        }

        /// <summary>
        /// Picks the value whose slot contains the normalized position, where the range [0, 1] is
        /// divided into Values.Length equal slots. Positions of 1 and above select the last value.
        /// </summary>
        // REVIEW: Is float accurate enough?
        public IParameterValue CreateFromNormalized(Double normalizedValue)
        {
            var values = _options.Values;
            // Length * 1.0 would index one past the end, so clamp to the last entry.
            int index = Math.Min((int)(values.Length * normalizedValue), values.Length - 1);
            return new StringParameterValue(_options.Name, values[index]);
        }

        /// <summary>The i-th value, wrapped in a new <see cref="StringParameterValue"/>.</summary>
        public IParameterValue this[int i]
        {
            get
            {
                return new StringParameterValue(_options.Name, _options.Values[i]);
            }
        }

        /// <summary>Number of values.</summary>
        public int Count
        {
            get
            {
                return _options.Values.Length;
            }
        }

        /// <summary>
        /// Returns " p=dp{...}" wrapping what CmdParser.GetSettings produces for these options
        /// against a default <see cref="DiscreteParamOptions"/>.
        /// </summary>
        public string ToStringParameter(IHostEnvironment env)
        {
            return $" p=dp{{{CmdParser.GetSettings(env, _options, new DiscreteParamOptions())}}}";
        }
    }
 
    /// <summary>
    /// Turns the textual "suggested sweep" of a parameter into the matching value generator.
    /// </summary>
    public sealed class SuggestedSweepsParser
    {
        /// <summary>
        /// Generic parameter parser. Currently hand-hacked to auto-detect type.
        ///
        /// Generic form:   Name:Values
        /// for example,    lr:0.05-0.4
        ///          lambda:0.1-1000@log10
        ///          nl:2-64@log2
        ///          norm:-,+
        ///
        /// A value containing ',' becomes a discrete sweep. Otherwise, for numeric parameter types,
        /// the value is read as "min-max" optionally followed by '@' or ';' and a ';'-separated list
        /// of options: "log..." (log scale; any text after "log" such as the base is not used),
        /// "steps:N" and "inc:X". Numbers are read with the invariant culture.
        /// </summary>
        /// <param name="paramValue">The sweep text (the part after "Name:").</param>
        /// <param name="paramType">Type of the parameter; decides between long and float sweeps.</param>
        /// <param name="paramName">Name given to the generator and used in error messages.</param>
        /// <param name="sweepValues">The created generator, or null on failure.</param>
        /// <param name="error">A message when the text looks like a sweep but is malformed; otherwise null.</param>
        /// <returns>True when a generator was created. False with a null <paramref name="error"/>
        /// means the text was not recognized as a sweep at all.</returns>
        /// REVIEW: allow overriding auto-detection to specify type
        /// and delegate to parameter type for actual parsing
        /// REVIEW: specifying ordinal discrete parameters
        public bool TryParseParameter(string paramValue, Type paramType, string paramName, out IValueGenerator sweepValues, out string error)
        {
            sweepValues = null;
            error = null;

            // Any comma makes this a discrete list, whatever the parameter type is.
            if (paramValue.Contains(','))
            {
                var discreteOptions = new DiscreteParamOptions();
                discreteOptions.Name = paramName;
                discreteOptions.Values = paramValue.Split(',');
                sweepValues = new DiscreteValueGenerator(discreteOptions);
                return true;
            }

            // numeric parameter
            if (!CmdParser.IsNumericType(paramType))
                return false;

            // REVIEW:  deal with negative bounds
            // Everything after the first '@' (or, failing that, the first ';') is the option list.
            string scaleStr = null;
            int atIdx = paramValue.IndexOf('@');
            if (atIdx < 0)
                atIdx = paramValue.IndexOf(';');
            if (atIdx >= 0)
            {
                scaleStr = paramValue.Substring(atIdx + 1);
                paramValue = paramValue.Substring(0, atIdx);
                if (scaleStr.Length < 1)
                {
                    error = $"Could not parse sweep range for parameter: {paramName}";
                    return false;
                }
            }

            // Extract the minimum and the maximum value of the suggested sweep.
            // The lookbehind splits at every '-' that is preceded by a character other than 'e'/'E',
            // so neither the sign of an exponent (1e-5) nor a leading minus sign acts as separator.
            // It is used for the Float and Long param types.
            // Example format: "0.02-0.1;steps:5".
            string[] minMax = Regex.Split(paramValue, "(?<=[^eE])-");
            if (minMax.Length != 2)
            {
                // A single piece is simply not a range (no error); more than two is malformed.
                if (minMax.Length > 2)
                    error = $"Could not parse sweep range for parameter: {paramName}";

                return false;
            }
            string minStr = minMax[0];
            string maxStr = minMax[1];

            int numSteps = 100;
            Double stepSize = -1;
            bool logBase = false;
            if (scaleStr != null && !TryParseScaleOptions(scaleStr, paramName, out numSteps, out stepSize, out logBase, out error))
                return false;

            // A non-positive stepSize means "not specified": the generator derives it from numSteps.
            Double? explicitStep = stepSize > 0 ? stepSize : (Double?)null;

            if (paramType == typeof(UInt16)
                || paramType == typeof(UInt32)
                || paramType == typeof(UInt64)
                || paramType == typeof(short)
                || paramType == typeof(int)
                || paramType == typeof(long))
            {
                long min;
                long max;
                // Invariant culture, consistent with how the "inc" option is parsed.
                if (!long.TryParse(minStr, NumberStyles.Integer, CultureInfo.InvariantCulture, out min)
                    || !long.TryParse(maxStr, NumberStyles.Integer, CultureInfo.InvariantCulture, out max))
                    return false;
                var longOptions = new LongParamOptions();
                longOptions.Name = paramName;
                longOptions.Min = min;
                longOptions.Max = max;
                longOptions.NumSteps = numSteps;
                longOptions.StepSize = explicitStep;
                longOptions.LogBase = logBase;

                try
                {
                    sweepValues = new LongValueGenerator(longOptions);
                }
                catch (Exception e)
                {
                    // The generator constructor rejects inconsistent options via contract checks.
                    error = $"Error creating sweep generator for parameter '{paramName}': {e.Message}";
                    return false;
                }
            }
            else
            {
                float minF;
                float maxF;
                // Invariant culture: with the current culture "0.05" is misread on machines whose
                // culture uses ',' as the decimal separator.
                const NumberStyles floatStyle = NumberStyles.Float | NumberStyles.AllowThousands;
                if (!float.TryParse(minStr, floatStyle, CultureInfo.InvariantCulture, out minF)
                    || !float.TryParse(maxStr, floatStyle, CultureInfo.InvariantCulture, out maxF))
                    return false;
                var floatOptions = new FloatParamOptions();
                floatOptions.Name = paramName;
                floatOptions.Min = minF;
                floatOptions.Max = maxF;
                floatOptions.NumSteps = numSteps;
                floatOptions.StepSize = explicitStep;
                floatOptions.LogBase = logBase;

                try
                {
                    sweepValues = new FloatValueGenerator(floatOptions);
                }
                catch (Exception e)
                {
                    // The generator constructor rejects inconsistent options via contract checks.
                    error = $"Error creating sweep generator for parameter '{paramName}': {e.Message}";
                    return false;
                }
            }
            return true;
        }

        // Parses the ';'-separated option list ("log...", "steps:N", "inc:X") that follows the range.
        // Returns false and sets error when a number cannot be parsed, when an option is not
        // recognized, or when the same kind of option appears more than once.
        private static bool TryParseScaleOptions(string scaleStr, string paramName,
            out int numSteps, out Double stepSize, out bool logBase, out string error)
        {
            numSteps = 100;
            stepSize = -1;
            logBase = false;
            error = null;

            string[] scaleOptions = scaleStr.Split(';');
            // One flag per kind of option: [0] log, [1] steps, [2] inc.
            bool[] optionsSpecified = new bool[3];
            try
            {
                foreach (string option in scaleOptions)
                {
                    // "log-..." and "log:-..." are deliberately not accepted as the log option.
                    if (option.StartsWith("log", StringComparison.Ordinal)
                        && !option.StartsWith("log-", StringComparison.Ordinal)
                        && !option.StartsWith("log:-", StringComparison.Ordinal))
                    {
                        logBase = true;
                        optionsSpecified[0] = true;
                    }
                    if (option.StartsWith("steps", StringComparison.Ordinal))
                    {
                        numSteps = int.Parse(option.Substring(option.IndexOf(':') + 1), CultureInfo.InvariantCulture);
                        optionsSpecified[1] = true;
                    }
                    if (option.StartsWith("inc", StringComparison.Ordinal))
                    {
                        stepSize = Double.Parse(option.Substring(option.IndexOf(':') + 1), CultureInfo.InvariantCulture);
                        optionsSpecified[2] = true;
                    }
                }
            }
            catch (Exception e) when (e is FormatException || e is OverflowException)
            {
                error = $"Error creating sweep generator for parameter '{paramName}': {e.Message}";
                return false;
            }

            // Every piece must have matched a distinct kind of option.
            if (scaleOptions.Length != optionsSpecified.Count(b => b))
            {
                error = $"Could not parse sweep range for parameter: {paramName}";
                return false;
            }
            return true;
        }
    }
}
 |