// File: SplitUtilTests.cs
// Web Access
// Project: src\test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj (Microsoft.ML.AutoML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.AutoML.Test
{
 
    /// <summary>
    /// Tests for SplitUtil.CrossValSplit, covering data too small to split at all,
    /// data where only some of the requested folds are usable, and data large enough
    /// to produce every requested fold.
    /// </summary>
    public class SplitUtilTests : BaseTestClass
    {
        public SplitUtilTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// When there's only one row of data, assert that attempted cross validation
        /// throws: every split would have an empty train or test set.
        /// </summary>
        [Fact]
        public void CrossValSplitThrowsWhenNotEnoughData()
        {
            var mlContext = new MLContext(seed: 1);
            var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
            dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
            dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, 0f);
            var dataView = dataViewBuilder.GetDataView();

            Assert.Throws<InvalidOperationException>(() => SplitUtil.CrossValSplit(mlContext, dataView, 10, null));
        }

        /// <summary>
        /// When there are few rows of data, assert that cross validation succeeds but
        /// returns fewer than the requested number of splits, and that every split it
        /// does return has a non-empty train set and a non-empty validation set
        /// (splits with an empty side should not be returned from this API).
        /// </summary>
        [Fact]
        public void CrossValSplitSmallDataView()
        {
            var mlContext = new MLContext(seed: 0);
            var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
            // 9 rows cannot give each of 10 folds a non-empty test set, so at least one
            // fold must be dropped.
            dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[9]);
            dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[9]);
            var dataView = dataViewBuilder.GetDataView();
            const int requestedNumSplits = 10;

            var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);

            // Materialize once so the sequences are not re-enumerated by every assertion.
            var trainSets = splits.trainDatasets.ToArray();
            var validationSets = splits.validationDatasets.ToArray();
            Assert.InRange(trainSets.Length, 1, requestedNumSplits - 1);
            Assert.Equal(trainSets.Length, validationSets.Length);
            foreach (var trainSet in trainSets)
            {
                AssertHasRows(trainSet);
            }
            foreach (var validationSet in validationSets)
            {
                AssertHasRows(validationSet);
            }
        }

        /// <summary>
        /// Assert that with many rows of data, cross validation produces exactly the
        /// requested number of train/validation splits.
        /// </summary>
        [Fact]
        public void CrossValSplitLargeDataView()
        {
            var mlContext = new MLContext(seed: 0);
            var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
            dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[10000]);
            dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[10000]);
            var dataView = dataViewBuilder.GetDataView();
            const int requestedNumSplits = 10;

            var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);

            Assert.Equal(requestedNumSplits, splits.trainDatasets.Count());
            Assert.Equal(requestedNumSplits, splits.validationDatasets.Count());
        }

        /// <summary>
        /// Asserts that <paramref name="dataView"/> yields at least one row.
        /// </summary>
        /// <param name="dataView">The data view expected to be non-empty.</param>
        private static void AssertHasRows(IDataView dataView)
        {
            // No columns are requested: only the existence of a row matters here.
            using (var cursor = dataView.GetRowCursor(Enumerable.Empty<DataViewSchema.Column>()))
            {
                Assert.True(cursor.MoveNext(), "Expected a non-empty data view, but it has no rows.");
            }
        }
    }
}