File: Option\ChoiceOption.cs
Web Access
Project: src\src\Microsoft.ML.SearchSpace\Microsoft.ML.SearchSpace.csproj (Microsoft.ML.SearchSpace)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Text.Json.Serialization;
using Microsoft.ML.SearchSpace.Converter;
 
#nullable enable
 
namespace Microsoft.ML.SearchSpace.Option
{
    /// <summary>
    /// Represents an option over a finite set of discrete values, such as strings or enum members.
    /// </summary>
    /// <remarks>
    /// Each choice is identified by its index in <see cref="Choices"/>. When there is more than one choice,
    /// the option occupies a single feature-space dimension. The index is mapped to and from that dimension
    /// through a <see cref="UniformSingleOption"/> ranging from 0 to the number of choices. An option with
    /// exactly one choice has no feature-space dimensions.
    /// </remarks>
    [JsonConverter(typeof(ChoiceOptionConverter))]
    public sealed class ChoiceOption : OptionBase
    {
        private readonly UniformSingleOption _option;

        /// <summary>
        /// Create <see cref="ChoiceOption"/> with <paramref name="choices"/>.
        /// The first choice is used as the default.
        /// </summary>
        /// <param name="choices">Candidate values; must contain between 1 and 1073 distinct values.</param>
        /// <exception cref="ArgumentNullException"><paramref name="choices"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="choices"/> is empty, has 1074 or more elements, or contains repeated values.
        /// </exception>
        public ChoiceOption(params object[] choices)
        {
            // Validate with real exceptions rather than Contract.Assert. Contract.Assert is
            // [Conditional("DEBUG")], so those checks were compiled out of release builds.
            if (choices is null)
            {
                throw new ArgumentNullException(nameof(choices));
            }

            // NOTE(review): the exclusive upper bound of 1074 is kept from the original implementation;
            // its rationale is not visible in this file.
            if (choices.Length == 0 || choices.Length >= 1074)
            {
                throw new ArgumentException("the length of choices must be (0, 1074)", nameof(choices));
            }

            // Materialize once so the distinct set is not recomputed for the count check and the projection.
            var distinctChoices = choices.Distinct().ToArray();
            if (distinctChoices.Length != choices.Length)
            {
                throw new ArgumentException("choices must not contain repeated values", nameof(choices));
            }

            Choices = distinctChoices.Select(o => Parameter.FromObject(o)).ToArray();
            _option = new UniformSingleOption(0, Choices.Length);
            Default = Enumerable.Repeat(0.0, FeatureSpaceDim).ToArray();
        }

        /// <summary>
        /// Create <see cref="ChoiceOption"/> with <paramref name="choices"/> and <paramref name="defaultChoice"/>.
        /// </summary>
        /// <param name="choices">Candidate values; must contain between 1 and 1073 distinct values.</param>
        /// <param name="defaultChoice">
        /// The default value. When <see langword="null"/>, the first choice is used as the default.
        /// </param>
        /// <exception cref="ArgumentException">
        /// <paramref name="choices"/> is invalid, or <paramref name="defaultChoice"/> is not one of the choices
        /// (checked only when there is more than one choice).
        /// </exception>
        public ChoiceOption(object[] choices, object? defaultChoice)
            : this(choices)
        {
            if (defaultChoice != null)
            {
                Default = MappingToFeatureSpace(Parameter.FromObject(defaultChoice));
            }
        }

        /// <summary>
        /// Get all choices, in the order they were supplied.
        /// </summary>
        public Parameter[] Choices { get; }

        /// <inheritdoc/>
        public override int FeatureSpaceDim => Choices.Length == 1 ? 0 : 1;

        /// <inheritdoc/>
        public override int?[] Step => new int?[] { Choices.Length };

        /// <inheritdoc/>
        /// <exception cref="ArgumentException"><paramref name="param"/> is not one of <see cref="Choices"/>.</exception>
        public override double[] MappingToFeatureSpace(Parameter param)
        {
            // A single-choice option has no feature-space dimensions.
            if (FeatureSpaceDim == 0)
            {
                return Array.Empty<double>();
            }

            var x = Array.IndexOf(Choices, param);
            if (x < 0)
            {
                throw new ArgumentException($"{param} not contains", nameof(param));
            }

            return _option.MappingToFeatureSpace(Parameter.FromInt(x));
        }

        /// <inheritdoc/>
        public override Parameter SampleFromFeatureSpace(double[] values)
        {
            // An empty feature vector (as produced for a single-choice option) maps to the first choice.
            if (values.Length == 0)
            {
                return Choices[0];
            }

            var param = _option.SampleFromFeatureSpace(values);
            var value = param.AsType<float>();
            var idx = Convert.ToInt32(Math.Floor(value));

            // idx will be equal to Choices.Length if value is [1], which is one past the last valid
            // index, so clamp it to the last choice.
            if (idx >= Choices.Length)
            {
                idx = Choices.Length - 1;
            }

            return Choices[idx];
        }
    }
}