File: Sweepers\ISweeper.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// The main interface of the sweeper: a component that proposes parameter configurations to run.
    /// </summary>
    internal interface ISweeper
    {
        /// <summary>
        /// Returns between 0 and <paramref name="maxSweeps"/> configurations to run.
        /// It expects a list of previous runs so that it can generate configurations that were not already tried.
        /// Some smart sweepers can take advantage of the metric(s) that the caller computes for previous runs.
        /// </summary>
        /// <param name="maxSweeps">The upper bound on the number of configurations returned.</param>
        /// <param name="previousRuns">The results of previous runs; can be null if there were no previous runs.</param>
        /// <returns>The proposed configurations, each expressed as a <see cref="ParameterSet"/>.</returns>
        ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null);
    }
 
    /// <summary>
    /// This is the interface that each type of parameter sweep needs to implement.
    /// A generator describes one named parameter and can produce values for it either
    /// from a normalized position or by index.
    /// </summary>
    internal interface IValueGenerator
    {
        /// <summary>
        /// Given a value in the [0,1] range, return a value for this parameter.
        /// </summary>
        /// <param name="normalizedValue">A position in the [0,1] range.</param>
        /// <returns>The parameter value corresponding to <paramref name="normalizedValue"/>.</returns>
        IParameterValue CreateFromNormalized(Double normalizedValue);
 
        /// <summary>
        /// Used mainly in grid sweepers, return the i-th distinct value for this parameter.
        /// </summary>
        /// <param name="i">The index of the distinct value to return.</param>
        IParameterValue this[int i] { get; }
 
        /// <summary>
        /// Used mainly in grid sweepers, return the count of distinct values for this parameter.
        /// </summary>
        int Count { get; }
 
        /// <summary>
        /// Returns the name of the generated parameter.
        /// </summary>
        string Name { get; }
    }
 
    /// <summary>
    /// Parameter value generated from the sweeping.
    /// The parameter values must be immutable.
    /// Value is converted to string because the runner will usually want to construct a command line for TL.
    /// Implementations of this interface must also override object.GetHashCode() and object.Equals(object) so they are consistent
    /// with IEquatable.Equals(IParameterValue).
    /// </summary>
    internal interface IParameterValue : IEquatable<IParameterValue>
    {
        /// <summary>
        /// The name of the parameter this value belongs to.
        /// </summary>
        string Name { get; }

        /// <summary>
        /// The value of the parameter, converted to a string.
        /// </summary>
        string ValueText { get; }
    }
 
    /// <summary>
    /// Type safe version of the IParameterValue interface.
    /// </summary>
    /// <typeparam name="TValue">The type of the parameter value.</typeparam>
    internal interface IParameterValue<out TValue> : IParameterValue
    {
        /// <summary>
        /// The parameter value in its typed form.
        /// </summary>
        TValue Value { get; }
    }
 
    /// <summary>
    /// A set of parameter values, keyed by parameter name.
    /// The parameter set must be immutable.
    /// Equality is value based: two sets are equal when they contain equal values under the same names.
    /// <see cref="Equals(ParameterSet)"/>, <see cref="Equals(object)"/> and <see cref="GetHashCode"/> are kept consistent.
    /// </summary>
    internal sealed class ParameterSet : IEquatable<ParameterSet>, IEnumerable<IParameterValue>
    {
        private readonly Dictionary<string, IParameterValue> _parameterValues;
        private readonly int _hash;

        /// <summary>
        /// Builds a set from the given values and precomputes its hash code.
        /// </summary>
        /// <param name="parameters">The values to store. Each value is indexed by its name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="parameters"/> is null.</exception>
        /// <exception cref="ArgumentException">Two values share the same name.</exception>
        public ParameterSet(IEnumerable<IParameterValue> parameters)
        {
            if (parameters == null)
                throw new ArgumentNullException(nameof(parameters));

            _parameterValues = new Dictionary<string, IParameterValue>();
            foreach (var parameter in parameters)
            {
                _parameterValues.Add(parameter.Name, parameter);
            }

            // The hash must not depend on insertion order, so the values are combined in name order.
            // An ordinal sort keeps that order independent of the current thread's culture.
            var parameterNames = _parameterValues.Keys.ToList();
            parameterNames.Sort(StringComparer.Ordinal);
            var hash = 0;
            foreach (var parameterName in parameterNames)
            {
                hash = Hashing.CombineHash(hash, _parameterValues[parameterName].GetHashCode());
            }

            _hash = hash;
        }

        /// <summary>
        /// Wraps an existing dictionary together with an already known hash code.
        /// The dictionary is stored as is (it is not copied) and the hash is not recomputed.
        /// </summary>
        /// <param name="paramValues">The values of the set, keyed by parameter name.</param>
        /// <param name="hash">The hash code to report for this set.</param>
        public ParameterSet(Dictionary<string, IParameterValue> paramValues, int hash)
        {
            _parameterValues = paramValues;
            _hash = hash;
        }

        /// <summary>
        /// Enumerates the parameter values of the set.
        /// </summary>
        public IEnumerator<IParameterValue> GetEnumerator()
        {
            return _parameterValues.Values.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>
        /// The number of parameter values in the set.
        /// </summary>
        public int Count
        {
            get { return _parameterValues.Count; }
        }

        /// <summary>
        /// Returns the value stored under the given parameter name.
        /// </summary>
        /// <param name="name">The name of the parameter to look up.</param>
        /// <exception cref="KeyNotFoundException">No value is stored under <paramref name="name"/>.</exception>
        public IParameterValue this[string name]
        {
            get { return _parameterValues[name]; }
        }

        // True when this set holds a value with the same name that is equal to the given one.
        private bool ContainsParamValue(IParameterValue parameterValue)
        {
            IParameterValue value;
            return _parameterValues.TryGetValue(parameterValue.Name, out value) &&
                   parameterValue.Equals(value);
        }

        /// <summary>
        /// Determines whether the other set contains exactly the same parameter values as this one.
        /// </summary>
        /// <param name="other">The set to compare with; null is never equal.</param>
        public bool Equals(ParameterSet other)
        {
            if (ReferenceEquals(this, other))
                return true;

            // Hash and count are cheap rejections before the value-by-value comparison.
            if (other == null || other._hash != _hash || other._parameterValues.Count != _parameterValues.Count)
                return false;
            return other._parameterValues.Values.All(pv => ContainsParamValue(pv));
        }

        /// <summary>
        /// Object-based equality, consistent with <see cref="Equals(ParameterSet)"/> and <see cref="GetHashCode"/>.
        /// </summary>
        /// <param name="obj">The object to compare with; anything that is not a <see cref="ParameterSet"/> is not equal.</param>
        public override bool Equals(object obj) => Equals(obj as ParameterSet);

        /// <summary>
        /// Creates a new set backed by a copy of this set's dictionary and carrying the same hash code.
        /// </summary>
        public ParameterSet Clone() =>
            new ParameterSet(new Dictionary<string, IParameterValue>(_parameterValues), _hash);

        /// <summary>
        /// Formats the set as space separated "name=value" pairs.
        /// </summary>
        public override string ToString()
        {
            return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)));
        }

        /// <summary>
        /// Returns the hash code that was fixed at construction time.
        /// </summary>
        public override int GetHashCode()
        {
            return _hash;
        }
    }
 
    /// <summary>
    /// The result of a run.
    /// Contains the parameter set used, useful for the sweeper to not generate the same configuration multiple times.
    /// Also contains the result of a run and the metric value that is used by smart sweepers to generate new configurations
    /// that try to maximize this metric.
    /// </summary>
    internal interface IRunResult : IComparable<IRunResult>
    {
        /// <summary>
        /// The parameter set that was used for the run.
        /// </summary>
        ParameterSet ParameterSet { get; }

        /// <summary>
        /// The metric value computed for the run.
        /// </summary>
        IComparable MetricValue { get; }

        /// <summary>
        /// Whether the metric is one to maximize, as the name indicates.
        /// How this affects comparison between results is up to the implementation.
        /// </summary>
        bool IsMetricMaximizing { get; }
    }
 
    /// <summary>
    /// Type safe version of the IRunResult interface.
    /// </summary>
    /// <typeparam name="T">The type of the metric value.</typeparam>
    internal interface IRunResult<T> : IRunResult
        where T : IComparable<T>
    {
        /// <summary>
        /// The metric value in its typed form; hides the untyped <see cref="IRunResult.MetricValue"/>.
        /// </summary>
        new T MetricValue { get; }
    }
 
    /// <summary>
    /// Simple implementation of IRunResult: a parameter set paired with a Double metric value
    /// and a flag telling whether larger or smaller metric values rank higher.
    /// </summary>
    internal sealed class RunResult : IRunResult<Double>
    {
        private readonly ParameterSet _parameterSet;
        private readonly Double? _metricValue;
        private readonly bool _isMetricMaximizing;

        /// <summary>
        /// This switch changes the behavior of the CompareTo function, switching the greater than / less than
        /// behavior, depending on if it is set to True.
        /// </summary>
        public bool IsMetricMaximizing { get { return _isMetricMaximizing; } }

        /// <summary>
        /// The parameter set that was used for the run.
        /// </summary>
        public ParameterSet ParameterSet
        {
            get { return _parameterSet; }
        }

        /// <summary>
        /// Creates a run result.
        /// </summary>
        /// <param name="parameterSet">The parameter set that was used for the run.</param>
        /// <param name="metricValue">The metric value computed for the run.</param>
        /// <param name="isMetricMaximizing">True if a larger metric value should rank higher in <see cref="CompareTo"/>.</param>
        public RunResult(ParameterSet parameterSet, Double metricValue, bool isMetricMaximizing)
        {
            _parameterSet = parameterSet;
            _metricValue = metricValue;
            _isMetricMaximizing = isMetricMaximizing;
        }

        /// <summary>
        /// The metric value computed for the run.
        /// </summary>
        public Double MetricValue
        {
            get
            {
                return _metricValue.Value;
            }
        }

        /// <summary>
        /// Compares this result with another one by metric value.
        /// When <see cref="IsMetricMaximizing"/> is true the result with the larger metric is the greater one;
        /// otherwise the result with the smaller metric is the greater one.
        /// </summary>
        /// <param name="other">The result to compare with. Must be a <see cref="RunResult"/> or null.</param>
        /// <returns>
        /// 0 when both metric values are equal, 1 when this instance is greater than <paramref name="other"/>
        /// (which includes <paramref name="other"/> being null), -1 otherwise.
        /// </returns>
        /// <exception cref="ArgumentException"><paramref name="other"/> is not a <see cref="RunResult"/>.</exception>
        public int CompareTo(IRunResult other)
        {
            // By the IComparable convention every instance is greater than null.
            if (other == null)
                return 1;

            var otherTyped = other as RunResult;
            if (otherTyped == null)
                throw new ArgumentException("A RunResult can only be compared to another RunResult.", nameof(other));

            if (_metricValue == otherTyped._metricValue)
                return 0;

            // NOTE(review): when either metric is NaN neither the equality test above nor the less-than test
            // below holds, so the result is 1 when maximizing and -1 when minimizing regardless of operand
            // order. Confirm NaN metrics cannot reach this comparison.
            return _isMetricMaximizing ^ (_metricValue < otherTyped._metricValue) ? 1 : -1;
        }

        /// <summary>
        /// Whether a metric value is present.
        /// </summary>
        public bool HasMetricValue
        {
            get
            {
                return _metricValue != null;
            }
        }

        IComparable IRunResult.MetricValue
        {
            get { return MetricValue; }
        }
    }
 
    /// <summary>
    /// The metric class, used by smart sweeping algorithms.
    /// It is a simple view over a primary metric and an optional distribution, and is kept
    /// decoupled from RunResult so that a move towards IDataView remains possible in the future.
    /// </summary>
    internal sealed class RunMetric
    {
        private readonly float _primaryMetric;
        private readonly float[] _metricDistribution;

        /// <summary>
        /// Creates a metric from its primary value and an optional distribution.
        /// </summary>
        /// <param name="primaryMetric">The primary metric of the run.</param>
        /// <param name="metricDistribution">The optional distribution; when given, its values are copied into a private array.</param>
        public RunMetric(float primaryMetric, IEnumerable<float> metricDistribution = null)
        {
            _metricDistribution = metricDistribution?.ToArray();
            _primaryMetric = primaryMetric;
        }

        /// <summary>
        /// The primary metric to optimize.
        /// This metric is usually an aggregate value for the run, for example, AUC, accuracy etc.
        /// By default, smart sweeping algorithms will maximize this metric.
        /// If you want to minimize, either negate this value or change the option in the arguments of the sweeper constructor.
        /// </summary>
        public float PrimaryMetric => _primaryMetric;

        /// <summary>
        /// The (optional) distribution of the metric.
        /// This distribution can be a secondary measure of how good a run was, e.g per-fold AUC, per-fold accuracy, (sampled) per-instance log loss etc.
        /// </summary>
        /// <returns>A fresh copy of the distribution, or null when none was supplied.</returns>
        public float[] GetMetricDistribution()
        {
            if (_metricDistribution != null)
            {
                // Hand out a copy so callers cannot modify the stored values.
                var copy = new float[_metricDistribution.Length];
                Array.Copy(_metricDistribution, copy, copy.Length);
                return copy;
            }

            return null;
        }
    }
}