File: Utils\SweepableParamAttributes.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Globalization;
using System.Linq;
using System.Text;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Used to indicate suggested sweep ranges for parameter sweeping.
    /// </summary>
    internal abstract class SweepableParam
    {
        public string Name { get; set; }
        private IComparable _rawValue;
        public virtual IComparable RawValue
        {
            get => _rawValue;
            set
            {
                if (!Frozen)
                    _rawValue = value;
            }
        }
 
        // The raw value will store an index for discrete parameters,
        // but sometimes we want the text or numeric value itself,
        // not the hot index. The processed value does that for discrete
        // params. For other params, it just returns the raw value itself.
        public virtual IComparable ProcessedValue() => _rawValue;
 
        // Allows for hyperparameter value freezing, so that sweeps
        // will not alter the current value when true.
        public bool Frozen { get; set; }
 
        // Allows the sweepable param to be set directly using the
        // available ValueText attribute on IParameterValues (from
        // the ParameterSets used in the old hyperparameter sweepers).
        public abstract void SetUsingValueText(string valueText);
 
        public abstract SweepableParam Clone();
    }
 
    internal sealed class SweepableDiscreteParam : SweepableParam
    {
        public object[] Options { get; }
 
        public SweepableDiscreteParam(string name, object[] values, bool isBool = false) : this(values, isBool)
        {
            Name = name;
        }
 
        public SweepableDiscreteParam(object[] values, bool isBool = false)
        {
            Options = isBool ? new object[] { false, true } : values;
        }
 
        public override IComparable RawValue
        {
            get => base.RawValue;
            set
            {
                var val = Convert.ToInt32(value);
                if (!Frozen && 0 <= val && val < Options.Length)
                    base.RawValue = val;
            }
        }
 
        public override void SetUsingValueText(string valueText)
        {
            for (int i = 0; i < Options.Length; i++)
                if (valueText == Options[i].ToString())
                    RawValue = i;
        }
 
        private static string TranslateOption(object o)
        {
            switch (o)
            {
                case float _:
                case double _:
                    return $"{o}f";
                case long _:
                case int _:
                case byte _:
                case short _:
                    return o.ToString();
                case bool _:
                    return o.ToString().ToLower();
                case Enum _:
                    var type = o.GetType();
                    var defaultName = $"Enums.{type.Name}.{o.ToString()}";
                    var name = type.FullName?.Replace("+", ".");
                    if (name == null)
                        return defaultName;
                    var index1 = name.LastIndexOf(".", StringComparison.Ordinal);
                    var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1;
                    if (index2 >= 0)
                        return $"{name.Substring(index2)}.{o.ToString()}";
                    return defaultName;
                default:
                    return $"\"{o}\"";
            }
        }
 
        public override SweepableParam Clone() =>
            new SweepableDiscreteParam(Name, Options) { RawValue = RawValue, Frozen = Frozen };
 
        public override string ToString()
        {
            var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
            return $"[{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]";
        }
 
        public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue];
    }
 
    internal sealed class SweepableFloatParam : SweepableParam
    {
        public float Min { get; }
        public float Max { get; }
        public float? StepSize { get; }
        public int? NumSteps { get; }
        public bool IsLogScale { get; }
 
        public SweepableFloatParam(string name, float min, float max, float stepSize = -1, int numSteps = -1,
            bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
        {
            Name = name;
        }
 
        public SweepableFloatParam(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
        {
            Min = min;
            Max = max;
            if (!stepSize.Equals(-1))
                StepSize = stepSize;
            if (numSteps != -1)
                NumSteps = numSteps;
            IsLogScale = isLogScale;
        }
 
        public override void SetUsingValueText(string valueText)
        {
            RawValue = float.Parse(valueText, CultureInfo.InvariantCulture);
        }
 
        public override SweepableParam Clone() =>
            new SweepableFloatParam(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
 
        public override string ToString()
        {
            var optional = new StringBuilder();
            if (StepSize != null)
                optional.Append($", stepSize:{StepSize}");
            if (NumSteps != null)
                optional.Append($", numSteps:{NumSteps}");
            if (IsLogScale)
                optional.Append($", isLogScale:true");
            var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
            return $"[{GetType().Name}({name}{Min}f, {Max}f{optional})]";
        }
    }
 
    internal sealed class SweepableLongParam : SweepableParam
    {
        public long Min { get; }
        public long Max { get; }
        public float? StepSize { get; }
        public int? NumSteps { get; }
        public bool IsLogScale { get; }
 
        public SweepableLongParam(string name, long min, long max, float stepSize = -1, int numSteps = -1,
            bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale)
        {
            Name = name;
        }
 
        public SweepableLongParam(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false)
        {
            Min = min;
            Max = max;
            if (!stepSize.Equals(-1))
                StepSize = stepSize;
            if (numSteps != -1)
                NumSteps = numSteps;
            IsLogScale = isLogScale;
        }
 
        public override void SetUsingValueText(string valueText)
        {
            RawValue = long.Parse(valueText);
        }
 
        public override SweepableParam Clone() =>
            new SweepableLongParam(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen };
 
        public override string ToString()
        {
            var optional = new StringBuilder();
            if (StepSize != null)
                optional.Append($", stepSize:{StepSize}");
            if (NumSteps != null)
                optional.Append($", numSteps:{NumSteps}");
            if (IsLogScale)
                optional.Append($", isLogScale:true");
            var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", ";
            return $"[{GetType().Name}({name}{Min}, {Max}{optional})]";
        }
    }
}