// File: TestSweeper.cs
// Web Access
// Project: src\test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj (Microsoft.ML.Sweeper.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
using Microsoft.ML.Sweeper.Algorithms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Sweeper.RunTests
{
    public sealed class TestSweeper : BaseTestBaseline
    {
        /// <summary>
        /// Creates the sweeper test class, forwarding the test output helper to <see cref="BaseTestBaseline"/>.
        /// </summary>
        /// <param name="output">Output helper passed through to the base class.</param>
        public TestSweeper(ITestOutputHelper output)
            : base(output)
        {
        }
 
        /// <summary>
        /// Builds a <see cref="LongValueGenerator"/> from the supplied options and checks that the
        /// normalized value 0.5 yields a parameter with the expected name and value text.
        /// </summary>
        [Theory]
        [InlineData("bla", 10, 20, false, 10, 1000, "15")]
        [InlineData("bla", 10, 1000, true, 10, 1000, "99")]
        [InlineData("bla", 10, 1000, true, 10, 3, "99")]
        public void TestLongValueSweep(string name, int min, int max, bool logBase, int stepSize, int numSteps, string valueText)
        {
            var options = new LongParamOptions()
            {
                Name = name,
                Min = min,
                Max = max,
                LogBase = logBase,
                StepSize = stepSize,
                NumSteps = numSteps
            };
            var generator = new LongValueGenerator(options);

            IParameterValue midpoint = generator.CreateFromNormalized(0.5);

            Assert.Equal(name, midpoint.Name);
            Assert.Equal(valueText, midpoint.ValueText);
        }
 
        /// <summary>
        /// Round-trips a long parameter through <c>NormalizeValue</c> and <c>CreateFromNormalized</c>
        /// in both directions on a generator spanning 0 to 17.
        /// </summary>
        [Fact]
        public void TestLongValueGeneratorRoundTrip()
        {
            var options = new LongParamOptions() { Name = "bla", Min = 0, Max = 17 };
            var generator = new LongValueGenerator(options);

            // Direction 1: concrete value -> normalized -> concrete value.
            var five = new LongParameterValue("bla", 5);
            float normalizedFive = generator.NormalizeValue(five);
            IParameterValue restored = generator.CreateFromNormalized(normalizedFive);
            Assert.Equal("5", restored.ValueText);

            // Direction 2: normalized -> concrete value -> normalized.
            IParameterValue generated = generator.CreateFromNormalized(0.345);
            float renormalized = generator.NormalizeValue(generated);
            Assert.Equal("5", generated.ValueText);
            Assert.Equal(5f / 17, renormalized);
        }
 
        /// <summary>
        /// Builds a <see cref="FloatValueGenerator"/> from the supplied options and checks that the
        /// normalized value 0.5 yields a parameter with the expected name and value text.
        /// </summary>
        [Theory]
        [InlineData("bla", 10, 20, false, 10, 1000, "15")]
        [InlineData("bla", 10, 1000, true, 10, 1000, "100")]
        [InlineData("bla", 10, 1000, true, 10, 3, "100")]
        public void TestFloatValueSweep(string name, int min, int max, bool logBase, int stepSize, int numSteps, string valueText)
        {
            var options = new FloatParamOptions()
            {
                Name = name,
                Min = min,
                Max = max,
                LogBase = logBase,
                StepSize = stepSize,
                NumSteps = numSteps
            };
            var generator = new FloatValueGenerator(options);

            IParameterValue midpoint = generator.CreateFromNormalized(0.5);

            Assert.Equal(name, midpoint.Name);
            Assert.Equal(valueText, midpoint.ValueText);
        }
 
        /// <summary>
        /// Round-trips a float parameter: normalized -> value -> normalized -> value, checking that
        /// both the normalized representation and the concrete value survive the trip.
        /// </summary>
        [Fact]
        public void TestFloatValueGeneratorRoundTrip()
        {
            var paramSweep = new FloatValueGenerator(new FloatParamOptions() { Name = "bla", Min = 1, Max = 5 });

            // Fixed seed keeps the sampled normalized value reproducible between runs.
            var random = new Random(123);
            var normalizedValue = (float)random.NextDouble();

            var value = (FloatParameterValue)paramSweep.CreateFromNormalized(normalizedValue);
            var roundTrippedNormalizedValue = paramSweep.NormalizeValue(value);
            Assert.Equal(normalizedValue, roundTrippedNormalizedValue);

            // Rebuild the value from the round-tripped normalized value (not from the original
            // input), so this assertion actually exercises the second leg of the round trip
            // instead of comparing two identical CreateFromNormalized calls.
            var roundTrippedValue = (FloatParameterValue)paramSweep.CreateFromNormalized(roundTrippedNormalizedValue);
            Assert.Equal(value.Value, roundTrippedValue.Value);
        }
 
        /// <summary>
        /// Checks which of the three discrete values ("foo", "bar", "baz") a
        /// <see cref="DiscreteValueGenerator"/> selects for a given normalized value.
        /// </summary>
        [Theory]
        [InlineData(0.5, "bar")]
        [InlineData(0.75, "baz")]
        public void TestDiscreteValueSweep(double normalizedValue, string expected)
        {
            var options = new DiscreteParamOptions()
            {
                Name = "bla",
                Values = new[] { "foo", "bar", "baz" }
            };
            var generator = new DiscreteValueGenerator(options);

            var picked = generator.CreateFromNormalized(normalizedValue);

            Assert.Equal("bla", picked.Name);
            Assert.Equal(expected, picked.ValueText);
        }
 
        /// <summary>
        /// Asks a <see cref="UniformRandomSweeper"/> for five parameter sets over two long ranges and
        /// checks that every proposed value lies inside the range configured for its parameter.
        /// </summary>
        [Fact]
        public void TestRandomSweeper()
        {
            var env = new MLContext(42);
            var sweeperOptions = new SweeperBase.OptionsBase()
            {
                SweptParameters = new[] {
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "foo", Min = 10, Max = 20 })),
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 100, Max = 200 }))
                    }
            };
            var sweeper = new UniformRandomSweeper(env, sweeperOptions);

            var proposals = sweeper.ProposeSweeps(5, new List<RunResult>());

            Assert.Equal(5, proposals.Length);
            foreach (var proposal in proposals)
            {
                foreach (var parameter in proposal)
                {
                    switch (parameter.Name)
                    {
                        case "foo":
                            Assert.InRange(long.Parse(parameter.ValueText), 10, 20);
                            break;
                        case "bar":
                            Assert.InRange(long.Parse(parameter.ValueText), 100, 200);
                            break;
                        default:
                            // Only the two swept parameters may appear in a proposal.
                            Assert.Fail("Wrong parameter");
                            break;
                    }
                }
            }
        }
 
        /// <summary>
        /// Drives a <see cref="SimpleAsyncSweeper"/> for 100 proposals twice: once reporting a result
        /// after every proposal, and once (with grid sweeper options) without ever calling Update.
        /// Both runs validate the proposed values via <see cref="CheckAsyncSweeperResult"/>.
        /// </summary>
        [Fact]
        public async Task TestSimpleSweeperAsync()
        {
            const int sweeps = 100;
            var random = new Random(42);
            var env = new MLContext(42);

            var sweeperOptions = new SweeperBase.OptionsBase
            {
                SweptParameters = new IComponentFactory<IValueGenerator>[] {
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5 })),
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                    }
            };
            var sweeper = new SimpleAsyncSweeper(env, sweeperOptions);

            // Phase 1: propose, then immediately feed a random metric back to the sweeper.
            var paramSets = new List<ParameterSet>();
            for (int sweep = 0; sweep < sweeps; sweep++)
            {
                var proposal = sweeper.ProposeAsync();
                var proposed = await proposal;
                Assert.True(proposal.IsCompleted);
                paramSets.Add(proposed.ParameterSet);
                sweeper.Update(proposed.Id, new RunResult(proposed.ParameterSet, random.NextDouble(), true));
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);

            // Phase 2: test consumption without ever calling Update.
            var gridArgs = new RandomGridSweeper.Options
            {
                SweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 100, LogBase = true }))
                    }
            };
            var gridSweeper = new SimpleAsyncSweeper(env, gridArgs);
            paramSets.Clear();
            for (int sweep = 0; sweep < sweeps; sweep++)
            {
                var proposal = gridSweeper.ProposeAsync();
                var proposed = await proposal;
                Assert.True(proposal.IsCompleted);
                paramSets.Add(proposed.ParameterSet);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);
        }
 
        /// <summary>
        /// Requests 20 proposals from a <see cref="DeterministicSweeperAsync"/> (batch size 5,
        /// relaxation 1) wrapping a <see cref="KdoSweeper"/>, reports results only for the first
        /// BatchSize - Relaxation proposals, cancels, and checks that exactly 2 * BatchSize
        /// proposals ended up with a non-null result.
        /// </summary>
        [Fact]
        public async Task TestDeterministicSweeperAsyncCancellation()
        {
            const int sweeps = 20;
            var random = new Random(42);
            var env = new MLContext(42);

            var args = new DeterministicSweeperAsync.Options();
            args.BatchSize = 5;
            args.Relaxation = 1;
            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new KdoSweeper(environ,
                    new KdoSweeper.Options()
                    {
                        SweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                        }
                    }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            // Proposals below this index are expected to be complete as soon as they are requested.
            int immediateCount = args.BatchSize - args.Relaxation;
            var pending = new List<Task<ParameterSetWithId>>();
            int numCompleted = 0;
            for (int i = 0; i < sweeps; i++)
            {
                var proposal = sweeper.ProposeAsync();
                if (i >= immediateCount)
                {
                    pending.Add(proposal);
                    continue;
                }

                Assert.True(proposal.IsCompleted);
                var completed = proposal.CompletedResult();
                sweeper.Update(completed.Id, new RunResult(completed.ParameterSet, random.NextDouble(), true));
                numCompleted++;
            }

            // Cancel after the first barrier and check if the number of registered actions
            // is indeed 2 * batchSize.
            sweeper.Cancel();
            await Task.WhenAll(pending);
            numCompleted += pending.Count(t => t.CompletedResult() != null);

            Assert.Equal(2 * args.BatchSize, numCompleted);
        }
 
        /// <summary>
        /// Exercises <see cref="DeterministicSweeperAsync"/> over a <see cref="SmacSweeper"/> in two ways:
        /// first as a single-threaded propose/update loop, then with two batches where the second
        /// batch of proposals must stay pending until results for the first batch are posted.
        /// </summary>
        [Fact]
        public async Task TestDeterministicSweeperAsync()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            var args = new DeterministicSweeperAsync.Options();
            args.BatchSize = 5;
            args.Relaxation = args.BatchSize - 1;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new SmacSweeper(environ,
                    new SmacSweeper.Options()
                    {
                        SweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                        }
                    }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            // Test single-threaded consumption.
            int sweeps = 10;
            var paramSets = new List<ParameterSet>();
            for (int i = 0; i < sweeps; i++)
            {
                var task = sweeper.ProposeAsync();
                Assert.True(task.IsCompleted);
                paramSets.Add(task.CompletedResult().ParameterSet);
                var result = new RunResult(task.CompletedResult().ParameterSet, random.NextDouble(), true);
                sweeper.Update(task.CompletedResult().Id, result);
            }
            Assert.Equal(sweeps, paramSets.Count);
            CheckAsyncSweeperResult(paramSets);

            // Create two batches and test if the 2nd batch is executed after the synchronization barrier is reached.
            // The array holds exactly the two batches filled below. Sizing it by BatchSize (rather than
            // by the unrelated 'sweeps' count) guarantees there are neither null slots handed to
            // Task.WhenAll nor out-of-range writes if either number is changed.
            var tasks = new Task<ParameterSetWithId>[2 * args.BatchSize];
            args.Relaxation = args.Relaxation - 1;
            sweeper = new DeterministicSweeperAsync(env, args);
            paramSets.Clear();
            var results = new List<KeyValuePair<int, IRunResult>>();
            for (int i = 0; i < args.BatchSize; i++)
            {
                var task = sweeper.ProposeAsync();
                Assert.True(task.IsCompleted);
                tasks[i] = task;
                if (task.CompletedResult() == null)
                    continue;
                results.Add(new KeyValuePair<int, IRunResult>(task.CompletedResult().Id, new RunResult(task.CompletedResult().ParameterSet, 0.42, true)));
            }
            // Register consumers for the 2nd batch. Those consumers will await until at least one run
            // in the previous batch has been posted to the sweeper.
            for (int i = args.BatchSize; i < 2 * args.BatchSize; i++)
            {
                var task = sweeper.ProposeAsync();
                Assert.False(task.IsCompleted);
                tasks[i] = task;
            }
            // Call update to unblock the 2nd batch.
            foreach (var run in results)
                sweeper.Update(run.Key, run.Value);

            await Task.WhenAll(tasks);
        }
 
        /// <summary>
        /// Runs 20 concurrent propose/update workers against a <see cref="DeterministicSweeperAsync"/>
        /// (batch size 5, relaxation 3) wrapping a <see cref="SmacSweeper"/>, waits for all of them,
        /// and validates every parameter set that was handed out.
        /// </summary>
        [Fact]
        public void TestDeterministicSweeperAsyncParallel()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            const int batchSize = 5;
            const int sweeps = 20;
            var paramSets = new List<ParameterSet>();
            var args = new DeterministicSweeperAsync.Options();
            args.BatchSize = batchSize;
            args.Relaxation = batchSize - 2;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ => new SmacSweeper(environ,
                    new SmacSweeper.Options()
                    {
                        SweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                                ComponentFactoryUtils.CreateFromFunction(
                                    t => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                        }
                    }));

            var sweeper = new DeterministicSweeperAsync(env, args);

            var mlock = new object();

            // Delay randomly to simulate doing work. The delays are drawn up front because
            // System.Random is not safe to share between the concurrent workers.
            int[] sleeps = new int[sweeps];
            for (int i = 0; i < sleeps.Length; i++)
                sleeps[i] = random.Next(10, 100);

            // Each worker is a real Task so its completion and any assertion failure inside it can
            // be observed. Passing an async lambda to Parallel.For would turn it into an async void
            // delegate: the loop would return at the first await, the checks below would race with
            // the still-running workers, and worker exceptions would escape the test. Because such
            // lambdas return at their first await, MaxDegreeOfParallelism never bounded them either,
            // so starting all workers at once keeps the same level of concurrency.
            Task[] workers = Enumerable.Range(0, sweeps).Select(i => Task.Run(async () =>
            {
                var task = sweeper.ProposeAsync();
                var paramWithId = await task;
                Assert.Equal(TaskStatus.RanToCompletion, task.Status);
                if (paramWithId == null)
                    return;
                await Task.Delay(sleeps[i]);
                var result = new RunResult(paramWithId.ParameterSet, 0.42, true);
                sweeper.Update(paramWithId.Id, result);
                lock (mlock)
                    paramSets.Add(paramWithId.ParameterSet);
            })).ToArray();

            // The test method is synchronous, so block here. The workers run on the thread pool and
            // do not capture this thread's context. GetResult rethrows the first worker failure.
            Task.WhenAll(workers).GetAwaiter().GetResult();

            Assert.True(paramSets.Count <= sweeps);
            CheckAsyncSweeperResult(paramSets);
        }
 
        /// <summary>
        /// Sequentially drives a <see cref="DeterministicSweeperAsync"/> (batch size 5, relaxation 0)
        /// wrapping a <see cref="NelderMeadSweeper"/> for up to 40 proposals, feeding back
        /// pre-generated random metrics, and validates every proposed parameter set.
        /// </summary>
        [Fact]
        public async Task TestNelderMeadSweeperAsync()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            const int batchSize = 5;
            const int sweeps = 40;
            var paramSets = new List<ParameterSet>();
            var args = new DeterministicSweeperAsync.Options();
            args.BatchSize = batchSize;
            args.Relaxation = 0;

            args.Sweeper = ComponentFactoryUtils.CreateFromFunction(
                environ =>
                {
                    var param = new IComponentFactory<INumericValueGenerator>[] {
                            ComponentFactoryUtils.CreateFromFunction(
                                innerEnviron => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                            ComponentFactoryUtils.CreateFromFunction(
                                innerEnviron => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                    };

                    var nelderMeadSweeperArgs = new NelderMeadSweeper.Options()
                    {
                        SweptParameters = param,
                        FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction<IValueGenerator[], ISweeper>(
                            (firstBatchSweeperEnviron, firstBatchSweeperArgs) =>
                                new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = param }))
                    };

                    return new NelderMeadSweeper(environ, nelderMeadSweeperArgs);
                }
            );

            var sweeper = new DeterministicSweeperAsync(env, args);
            double[] metrics = new double[sweeps];
            for (int i = 0; i < metrics.Length; i++)
                metrics[i] = random.NextDouble();

            // The loop awaits each proposal before requesting the next one, so paramSets is only
            // ever touched by one logical flow and needs no lock.
            for (int i = 0; i < sweeps; i++)
            {
                var paramWithId = await sweeper.ProposeAsync();
                if (paramWithId == null)
                {
                    // A null proposal ends the sweep early. Break rather than return, so the sets
                    // collected so far are still validated below.
                    break;
                }
                var result = new RunResult(paramWithId.ParameterSet, metrics[i], true);
                sweeper.Update(paramWithId.Id, result);
                paramSets.Add(paramWithId.ParameterSet);
            }
            Assert.True(paramSets.Count <= sweeps);
            CheckAsyncSweeperResult(paramSets);
        }
 
        /// <summary>
        /// Asserts that every parameter in every set is either "foo" with a float value in [1, 5]
        /// or "bar" with a long value in [1, 1000]; any other parameter name fails the test.
        /// </summary>
        /// <param name="paramSets">Parameter sets to validate; must not be null.</param>
        private void CheckAsyncSweeperResult(List<ParameterSet> paramSets)
        {
            Assert.NotNull(paramSets);
            foreach (var paramSet in paramSets)
            {
                foreach (var parameterValue in paramSet)
                {
                    switch (parameterValue.Name)
                    {
                        case "foo":
                            Assert.InRange(float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture), 1, 5);
                            break;
                        case "bar":
                            // Parse with the invariant culture, like the float branch, so the check
                            // does not depend on the culture of the machine running the tests.
                            Assert.InRange(long.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture), 1, 1000);
                            break;
                        default:
                            Assert.Fail("Wrong parameter");
                            break;
                    }
                }
            }
        }
 
        /// <summary>
        /// Sweeps a 3 x 3 grid ("foo" in {10, 15, 20}, "bar" in {100, 1000, 10000}) with a
        /// <see cref="RandomGridSweeper"/>. Checks that five proposals followed by the remaining
        /// four never repeat a grid point, and that a fresh request without history returns nine
        /// distinct points covering the whole grid.
        /// </summary>
        [Fact]
        public void TestRandomGridSweeper()
        {
            var env = new MLContext(42);
            var args = new RandomGridSweeper.Options()
            {
                SweptParameters = new[] {
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "foo", Min = 10, Max = 20, NumSteps = 3 })),
                        ComponentFactoryUtils.CreateFromFunction(
                            environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 100, Max = 10000, LogBase = true, StepSize = 10 }))
                    }
            };
            var sweeper = new RandomGridSweeper(env, args);

            // First five proposals: all on distinct grid points.
            var initialList = sweeper.ProposeSweeps(5, new List<RunResult>());
            Assert.Equal(5, initialList.Length);
            var gridPoint = NewGrid();
            MarkVisited(initialList, gridPoint);

            // With the first five reported as previous runs, only the four unseen points remain.
            var nextList = sweeper.ProposeSweeps(5, initialList.Select(p => new RunResult(p)));
            Assert.Equal(4, nextList.Length);
            MarkVisited(nextList, gridPoint);

            // Without any history a request for ten returns each of the nine grid points once.
            gridPoint = NewGrid();
            var lastList = sweeper.ProposeSweeps(10, null);
            Assert.Equal(9, lastList.Length);
            MarkVisited(lastList, gridPoint);
            Assert.True(gridPoint.All(bArray => bArray.All(b => b)));

            // Creates an empty 3 x 3 "visited" table indexed as [foo index][bar index].
            bool[][] NewGrid() => new bool[3][] { new bool[3], new bool[3], new bool[3] };

            // Validates each proposed set, maps it to its grid cell and asserts the cell was not
            // already visited before marking it.
            void MarkVisited(IEnumerable<ParameterSet> parameterSets, bool[][] grid)
            {
                foreach (var parameterSet in parameterSets)
                {
                    // Start from "unset" for every set so a set that lacks one of the parameters
                    // cannot silently reuse the coordinates of the previous set.
                    int fooIndex = -1;
                    int barIndex = -1;
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = long.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.True(val == 10 || val == 15 || val == 20);
                            fooIndex = (val == 10) ? 0 : (val == 15) ? 1 : 2;
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.True(val == 100 || val == 1000 || val == 10000);
                            barIndex = (val == 100) ? 0 : (val == 1000) ? 1 : 2;
                        }
                        else
                        {
                            Assert.Fail("Wrong parameter");
                        }
                    }

                    Assert.True(fooIndex >= 0 && barIndex >= 0, "Parameter set is missing 'foo' or 'bar'");
                    Assert.False(grid[fooIndex][barIndex]);
                    grid[fooIndex][barIndex] = true;
                }
            }
        }
 
        /// <summary>
        /// Runs a <see cref="NelderMeadSweeper"/> whose first batch comes from a
        /// <see cref="RandomGridSweeper"/>: expects three initial proposals, then feeds random
        /// metrics back for nine rounds while checking that every proposed value is in range.
        /// </summary>
        [Fact]
        public void TestNelderMeadSweeper()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            var sweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                };

            var sweeperOptions = new NelderMeadSweeper.Options();
            sweeperOptions.SweptParameters = sweptParameters;
            sweeperOptions.FirstBatchSweeper = ComponentFactoryUtils.CreateFromFunction<IValueGenerator[], ISweeper>(
                (environ, firstBatchArgs) =>
                    new RandomGridSweeper(environ, new RandomGridSweeper.Options() { SweptParameters = sweptParameters }));

            var sweeper = new NelderMeadSweeper(env, sweeperOptions);
            var proposals = sweeper.ProposeSweeps(5, new List<RunResult>());
            Assert.Equal(3, proposals.Length);

            // The history grows across rounds; each round passes the full history back in.
            var history = new List<IRunResult>();
            for (int round = 0; round < 9; round++)
            {
                foreach (var proposal in proposals)
                {
                    foreach (var parameter in proposal)
                    {
                        switch (parameter.Name)
                        {
                            case "foo":
                                Assert.InRange(float.Parse(parameter.ValueText, CultureInfo.InvariantCulture), 1, 5);
                                break;
                            case "bar":
                                Assert.InRange(long.Parse(parameter.ValueText), 1, 1000);
                                break;
                            default:
                                Assert.Fail("Wrong parameter");
                                break;
                        }
                    }
                    history.Add(new RunResult(proposal, random.NextDouble(), true));
                }

                proposals = sweeper.ProposeSweeps(5, history);
            }
            Assert.True(proposals.Length <= 5);
        }
 
        /// <summary>
        /// Runs a <see cref="NelderMeadSweeper"/> without specifying a first-batch sweeper: expects
        /// three initial proposals, then feeds random metrics back for up to nine rounds while
        /// checking that every proposed value is in range. A null proposal ends the sweep.
        /// </summary>
        [Fact]
        public void TestNelderMeadSweeperWithDefaultFirstBatchSweeper()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            var param = new IComponentFactory<INumericValueGenerator>[] {
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                    ComponentFactoryUtils.CreateFromFunction(
                        environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = 1000, LogBase = true }))
                };

            var args = new NelderMeadSweeper.Options();
            args.SweptParameters = param;
            var sweeper = new NelderMeadSweeper(env, args);
            var sweeps = sweeper.ProposeSweeps(5, new List<RunResult>());
            Assert.Equal(3, sweeps.Length);

            var results = new List<IRunResult>();
            for (int i = 1; i < 10; i++)
            {
                foreach (var parameterSet in sweeps)
                {
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.InRange(val, 1, 5);
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText);
                            Assert.InRange(val, 1, 1000);
                        }
                        else
                        {
                            Assert.Fail("Wrong parameter");
                        }
                    }
                    results.Add(new RunResult(parameterSet, random.NextDouble(), true));
                }

                sweeps = sweeper.ProposeSweeps(5, results);

                // The final assertion accepts a null proposal as a valid outcome, so stop here
                // instead of enumerating null on the next round.
                if (sweeps == null)
                    break;
            }
            Assert.True(sweeps == null || sweeps.Length <= 5);
        }
 
        /// <summary>
        /// Runs a <see cref="SmacSweeper"/> with an initial population of 20: checks the size of the
        /// first batch, then feeds random metrics back for nine rounds while checking that every
        /// proposed value lies inside the range configured for its parameter.
        /// </summary>
        [Fact]
        public void TestSmacSweeper()
        {
            var random = new Random(42);
            var env = new MLContext(42);
            const int maxInitSweeps = 5;
            // Batch size requested on every follow-up round.
            const int maxSweeps = 5;
            // Upper bound of "bar", shared by the generator options and the range assertion so
            // the check cannot drift away from the configuration.
            const long barMax = 100;
            var args = new SmacSweeper.Options()
            {
                NumberInitialPopulation = 20,
                SweptParameters = new IComponentFactory<INumericValueGenerator>[] {
                            ComponentFactoryUtils.CreateFromFunction(
                                environ => new FloatValueGenerator(new FloatParamOptions() { Name = "foo", Min = 1, Max = 5})),
                            ComponentFactoryUtils.CreateFromFunction(
                                environ => new LongValueGenerator(new LongParamOptions() { Name = "bar", Min = 1, Max = barMax, LogBase = true }))
                        }
            };

            var sweeper = new SmacSweeper(env, args);
            var results = new List<IRunResult>();
            var sweeps = sweeper.ProposeSweeps(maxInitSweeps, results);
            Assert.Equal(Math.Min(args.NumberInitialPopulation, maxInitSweeps), sweeps.Length);

            for (int i = 1; i < 10; i++)
            {
                foreach (var parameterSet in sweeps)
                {
                    foreach (var parameterValue in parameterSet)
                    {
                        if (parameterValue.Name == "foo")
                        {
                            var val = float.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.InRange(val, 1, 5);
                        }
                        else if (parameterValue.Name == "bar")
                        {
                            var val = long.Parse(parameterValue.ValueText, CultureInfo.InvariantCulture);
                            Assert.InRange(val, 1, barMax);
                        }
                        else
                        {
                            Assert.Fail("Wrong parameter");
                        }
                    }
                    results.Add(new RunResult(parameterSet, random.NextDouble(), true));
                }

                sweeps = sweeper.ProposeSweeps(maxSweeps, results);
            }
            // Because only unique configurations are considered, the number asked for may exceed the number actually returned.
            Assert.True(sweeps.Length <= maxSweeps);
        }
    }
}