// File: Scenarios\Api\TestApi.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests.Scenarios.Api
{
    public partial class TestApi : BaseTestClass
    {
        /// <summary>Relative sub-path handed to <c>GetOutputPath</c> when a test in this class writes an output file.</summary>
        private const string OutputRelativePath = "../Common/Api";
 
        /// <summary>
        /// Creates the test class and forwards the output helper to <see cref="BaseTestClass"/>.
        /// </summary>
        /// <param name="output">Test output helper passed through to the base class.</param>
        public TestApi(ITestOutputHelper output)
            : base(output)
        {
        }
 
        // Row type with a single IChannel field carrying the cursor-channel attribute.
        // CursorChannelExposedInMapTransform expects the lambda transform to hand its
        // callback an instance whose Channel is non-null.
        private class OneIChannelWithAttribute
        {
            // Starts out null; the test asserts it is still null on the source objects.
            [CursorChannel]
            public IChannel Channel = null;
        }
 
        // Row type that puts the cursor-channel attribute on a string field rather than
        // an IChannel field. CursorChannelExposedInMapTransform expects cursoring over
        // a transform built on this type to throw.
        private class OneStringWithAttribute
        {
            public string OutField = "outfield is a string";

            [CursorChannel]
            public string Channel = null;
        }
 
        // Row type with two IChannel fields that both carry the cursor-channel attribute.
        // CursorChannelExposedInMapTransform expects cursoring over a transform built on
        // this type to throw.
        private class TwoIChannelsWithAttributes
        {
            [CursorChannel]
            public IChannel ChannelOne = null;

            [CursorChannel]
            public IChannel ChannelTwo = null;
        }
 
        // Row type with two IChannel fields of which only ChannelTwo carries the
        // cursor-channel attribute. CursorChannelExposedInMapTransform expects the
        // unmarked ChannelOne to stay null while ChannelTwo is populated.
        private class TwoIChannelsOnlyOneWithAttribute
        {
            public string OutField = "Outfield is assigned";

            public IChannel ChannelOne = null;

            [CursorChannel]
            public IChannel ChannelTwo = null;
        }
 
        /// <summary>
        /// Exercises lambda filters over row types that use the cursor-channel attribute in four ways:
        /// one marked <see cref="IChannel"/> field (valid), a marked non-channel field (error),
        /// two marked channel fields (error), and one marked plus one unmarked channel field (valid).
        /// </summary>
        [Fact]
        public void CursorChannelExposedInMapTransform()
        {
            var env = new MLContext(seed: 0);

            // Correct use of CursorChannel attribute.
            var data1 = Utils.CreateArray(10, new OneIChannelWithAttribute());
            var idv1 = env.Data.LoadFromEnumerable(data1);
            Assert.Null(data1[0].Channel);

            var filter1 = LambdaTransform.CreateFilter<OneIChannelWithAttribute, object>(env, idv1,
                (input, state) =>
                {
                    Assert.NotNull(input.Channel);
                    return false;
                }, null);
            // Cursors are disposable, so release each one once the scenario has been driven.
            using (var cursor = filter1.GetRowCursorForAllColumns())
                cursor.MoveNext();

            // Error case: non-IChannel field marked with attribute.
            var data2 = Utils.CreateArray(10, new OneStringWithAttribute());
            var idv2 = env.Data.LoadFromEnumerable(data2);
            Assert.Null(data2[0].Channel);

            var filter2 = LambdaTransform.CreateFilter<OneStringWithAttribute, object>(env, idv2,
                (input, state) =>
                {
                    Assert.Null(input.Channel);
                    return false;
                }, null);
            try
            {
                using (var cursor = filter2.GetRowCursorForAllColumns())
                    cursor.MoveNext();
                Assert.Fail("Throw an error if attribute is applied to a field that is not an IChannel.");
            }
            catch (InvalidOperationException ex)
            {
                Assert.True(ex.IsMarked());
            }

            // Error case: multiple fields marked with attributes.
            var data3 = Utils.CreateArray(10, new TwoIChannelsWithAttributes());
            var idv3 = env.Data.LoadFromEnumerable(data3);
            Assert.Null(data3[0].ChannelOne);
            Assert.Null(data3[2].ChannelTwo);

            var filter3 = LambdaTransform.CreateFilter<TwoIChannelsWithAttributes, object>(env, idv3,
                (input, state) =>
                {
                    Assert.Null(input.ChannelOne);
                    Assert.Null(input.ChannelTwo);
                    return false;
                }, null);
            try
            {
                using (var cursor = filter3.GetRowCursorForAllColumns())
                    cursor.MoveNext();
                // The message now names the condition actually under test in this scenario.
                Assert.Fail("Throw an error if attribute is applied to more than one field.");
            }
            catch (InvalidOperationException ex)
            {
                Assert.True(ex.IsMarked());
            }

            // Correct case: non-marked IChannel field is not touched.
            var example4 = new TwoIChannelsOnlyOneWithAttribute();
            Assert.Null(example4.ChannelTwo);
            Assert.Null(example4.ChannelOne);
            var idv4 = env.Data.LoadFromEnumerable(Utils.CreateArray(10, example4));

            var filter4 = LambdaTransform.CreateFilter<TwoIChannelsOnlyOneWithAttribute, object>(env, idv4,
                (input, state) =>
                {
                    Assert.Null(input.ChannelOne);
                    Assert.NotNull(input.ChannelTwo);
                    return false;
                }, null);
            // Drive filter4 (not filter1) so the assertions inside its lambda actually execute.
            using (var cursor = filter4.GetRowCursorForAllColumns())
                cursor.MoveNext();
        }
 
        // In-memory row type for the breast-cancer dataset as read by ReadBreastCancerExamples.
        public class BreastCancerExample
        {
            // Parsed from the first tab-separated field of each data line.
            public float Label;

            // Parsed from the nine fields following the label; the attribute argument
            // matches the length of the array allocated by ReadBreastCancerExamples.
            [VectorType(9)]
            public float[] Features;
        }
 
        /// <summary>
        /// Builds a lambda filter over the breast-cancer examples, re-applies it to a freshly
        /// loaded copy of the same data, and saves the result as a TSV file.
        /// </summary>
        [Fact]
        public void LambdaTransformCreate()
        {
            var env = new MLContext(seed: 42);
            var examples = ReadBreastCancerExamples();
            var source = env.Data.LoadFromEnumerable(examples);

            var filter = LambdaTransform.CreateFilter<BreastCancerExample, object>(
                env,
                source,
                (input, state) => input.Label == 0,
                null);

            // The filter is expected to report no row count.
            Assert.Null(filter.GetRowCount());

            // Re-apply the filter's transforms on top of a second view over the same examples.
            var freshView = env.Data.LoadFromEnumerable(examples);
            var reapplied = ApplyTransformUtils.ApplyAllTransformsToData(env, filter, freshView);

            var saver = new TextSaver(env, new TextSaver.Arguments());
            Assert.True(reapplied.Schema.TryGetColumnIndex("Label", out int labelIndex));

            var outputPath = GetOutputPath(OutputRelativePath, "lambda-output.tsv");
            using (var stream = File.Create(outputPath))
            {
                saver.SaveData(stream, reapplied, labelIndex);
            }
        }
 
        /// <summary>
        /// Trains an averaged perceptron over explicitly cached data and checks how many times
        /// the counting callback of the upstream lambda filter was invoked.
        /// </summary>
        [Fact]
        public void TrainAveragedPerceptronWithCache()
        {
            var mlContext = new MLContext(0);
            var trainPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
            var loader = TextLoader.Create(mlContext, new TextLoader.Options(), new MultiFileSource(trainPath));

            // Pass-through filter (keeps every row) whose second callback only bumps a counter.
            var invocationCount = 0;
            IDataView pipeline = LambdaTransform.CreateFilter<object, object>(
                mlContext,
                loader,
                (i, s) => true,
                s => { invocationCount++; });

            var toBoolean = mlContext.Transforms.Conversion.ConvertType("Label", outputKind: DataKind.Boolean);
            pipeline = toBoolean.Fit(pipeline).Transform(pipeline);

            // The baseline result of this was generated with everything cached in memory. Since auto-caching
            // was removed, an explicit caching step is needed for this test to pass.
            var cached = mlContext.Data.Cache(pipeline);

            var trainerOptions = new AveragedPerceptronTrainer.Options { NumberOfIterations = 2 };
            var estimator = mlContext.BinaryClassification.Trainers.AveragedPerceptron(trainerOptions);

            estimator.Fit(cached).Transform(cached);

            // Despite two training iterations and the Transform call, the counting callback
            // must have been invoked exactly once.
            Assert.Equal(1, invocationCount);
        }
 
        /// <summary>
        /// Attaches annotations of several kinds (float, string, string array, float array, VBuffer)
        /// to an auto-generated schema definition, loads a data view with it, and verifies that the
        /// annotation names, types and values can be read back from the resulting schema.
        /// </summary>
        [Fact]
        public void MetadataSupportInDataViewConstruction()
        {
            var data = ReadBreastCancerExamples();
            var autoSchema = SchemaDefinition.Create(typeof(BreastCancerExample));

            var mlContext = new MLContext(0);

            // Create Metadata.
            var kindFloat = "Testing float as metadata.";
            float valueFloat = 10;
            var coltypeFloat = NumberDataViewType.Single;
            var kindString = "Testing string as metadata.";
            var valueString = "Strings have value.";
            var coltypeString = TextDataViewType.Instance;
            var kindStringArray = "Testing string array as metadata.";
            var valueStringArray = "I really have no idea what these features entail.".Split(' ');
            var coltypeStringArray = new VectorDataViewType(coltypeString, valueStringArray.Length);
            var kindFloatArray = "Testing float array as metadata.";
            var valueFloatArray = new float[] { 1, 17, 7, 19, 25, 0 };
            var coltypeFloatArray = new VectorDataViewType(coltypeFloat, valueFloatArray.Length);
            var kindVBuffer = "Testing VBuffer as metadata.";
            var valueVBuffer = new VBuffer<float>(4, new float[] { 4, 6, 89, 5 });
            var coltypeVBuffer = new VectorDataViewType(coltypeFloat, valueVBuffer.Length);

            // Add Metadata.
            var labelColumn = autoSchema[0];
            labelColumn.AddAnnotation(kindFloat, valueFloat, coltypeFloat);
            labelColumn.AddAnnotation(kindString, valueString, coltypeString);

            var featureColumn = autoSchema[1];
            featureColumn.AddAnnotation(kindStringArray, valueStringArray, coltypeStringArray);
            featureColumn.AddAnnotation(kindFloatArray, valueFloatArray, coltypeFloatArray);
            featureColumn.AddAnnotation(kindVBuffer, valueVBuffer, coltypeVBuffer);

            var idv = mlContext.Data.LoadFromEnumerable(data, autoSchema);

            // Use value-reporting assertions so a failure shows expected vs. actual.
            // Assert.Same keeps the reference comparison the original == on the types performed.
            var labelAnnotations = idv.Schema[0].Annotations;
            Assert.Equal(2, labelAnnotations.Schema.Count);
            Assert.Equal(kindFloat, labelAnnotations.Schema[0].Name);
            Assert.Same(coltypeFloat, labelAnnotations.Schema[0].Type);
            Assert.Equal(kindString, labelAnnotations.Schema[1].Name);
            Assert.Same(TextDataViewType.Instance, labelAnnotations.Schema[1].Type);

            var featureAnnotations = idv.Schema[1].Annotations;
            Assert.Equal(3, featureAnnotations.Schema.Count);
            Assert.Equal(kindStringArray, featureAnnotations.Schema[0].Name);
            Assert.True(featureAnnotations.Schema[0].Type is VectorDataViewType vectorType && vectorType.ItemType is TextDataViewType);
            // The float annotation was added to the label column only, so looking it up on the feature column must fail.
            Assert.Throws<ArgumentOutOfRangeException>(() => idv.Schema[1].Annotations.Schema[kindFloat]);

            float retrievedFloat = 0;
            idv.Schema[0].Annotations.GetValue(kindFloat, ref retrievedFloat);
            Assert.True(Math.Abs(retrievedFloat - valueFloat) < .000001);

            ReadOnlyMemory<char> retrievedReadOnlyMemory = new ReadOnlyMemory<char>();
            idv.Schema[0].Annotations.GetValue(kindString, ref retrievedReadOnlyMemory);
            Assert.True(retrievedReadOnlyMemory.Span.SequenceEqual(valueString.AsMemory().Span));

            VBuffer<ReadOnlyMemory<char>> retrievedReadOnlyMemoryVBuffer = new VBuffer<ReadOnlyMemory<char>>();
            idv.Schema[1].Annotations.GetValue(kindStringArray, ref retrievedReadOnlyMemoryVBuffer);
            // SequenceEqual also compares lengths; an element-wise Select/All check would pass
            // if the retrieved vector were shorter than the expected array.
            Assert.True(retrievedReadOnlyMemoryVBuffer.DenseValues().Select(s => s.ToString()).SequenceEqual(valueStringArray));

            // Start from a non-empty buffer of a different length than the stored annotation.
            VBuffer<float> retrievedFloatVBuffer = new VBuffer<float>(1, new float[] { 2 });
            idv.Schema[1].Annotations.GetValue(kindFloatArray, ref retrievedFloatVBuffer);
            VBuffer<float> valueFloatVBuffer = new VBuffer<float>(valueFloatArray.Length, valueFloatArray);
            Assert.True(retrievedFloatVBuffer.Items().SequenceEqual(valueFloatVBuffer.Items()));

            VBuffer<float> retrievedVBuffer = new VBuffer<float>();
            idv.Schema[1].Annotations.GetValue(kindVBuffer, ref retrievedVBuffer);
            Assert.True(retrievedVBuffer.Items().SequenceEqual(valueVBuffer.Items()));

            // Requesting an annotation kind that the feature column does not have must throw.
            Assert.Throws<InvalidOperationException>(() => idv.Schema[1].Annotations.GetValue(kindFloat, ref retrievedReadOnlyMemoryVBuffer));
        }
 
        /// <summary>
        /// Reads the breast-cancer training file into memory as a list of <see cref="BreastCancerExample"/>.
        /// </summary>
        /// <remarks>
        /// Empty lines and lines starting with <c>//</c> are skipped. Each remaining line is split on tabs;
        /// the first field is parsed as the label and the next nine as the features. Any status the parser
        /// reports is not inspected.
        /// </remarks>
        /// <returns>One example per data line, in file order.</returns>
        private List<BreastCancerExample> ReadBreastCancerExamples()
        {
            // Number of feature fields following the label on every data line.
            const int featureCount = 9;

            var dataFile = GetDataPath(TestDatasets.breastCancer.trainFilename);

            // read the data programmatically into memory
            var data = File.ReadAllLines(dataFile)
                // Ordinal comparison: "//" is a literal marker, not linguistic text.
                .Where(line => !string.IsNullOrEmpty(line) && !line.StartsWith("//", StringComparison.Ordinal))
                .Select(
                    line =>
                    {
                        var parts = line.Split('\t');
                        var ex = new BreastCancerExample
                        {
                            Features = new float[featureCount]
                        };

                        // AsSpan views the string's characters directly instead of copying them into a new array.
                        DoubleParser.Parse(parts[0].AsSpan(), out ex.Label);
                        for (int j = 0; j < featureCount; j++)
                        {
                            DoubleParser.Parse(parts[j + 1].AsSpan(), out ex.Features[j]);
                        }
                        return ex;
                    })
                .ToList();
            return data;
        }
 
        /// <summary>
        /// Checks that train/test and cross-validation splits, with and without a seed and a
        /// sampling key column, expose the same column count and column names as the input.
        /// </summary>
        [Fact]
        public void TestSplitsSchema()
        {
            var mlContext = new MLContext(0);
            var dataPath = GetDataPath("adult.tiny.with-schema.txt");

            var columns = new[]
            {
                new TextLoader.Column("Label", DataKind.Boolean, 0),
                new TextLoader.Column("Workclass", DataKind.String, 1),
                new TextLoader.Column("Education", DataKind.String, 2),
                new TextLoader.Column("Age", DataKind.Single, 9)
            };
            var fullInput = mlContext.Data.LoadFromTextFile(dataPath, columns, hasHeader: true);

            // Array initializers evaluate in order, so the split calls run in the sequence listed.
            var trainTestSplits = new[]
            {
                mlContext.Data.TrainTestSplit(fullInput),
                mlContext.Data.TrainTestSplit(fullInput, seed: 10),
                mlContext.Data.TrainTestSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass")
            };

            // Only the first fold of every cross-validation split is inspected.
            var firstFolds = new[]
            {
                mlContext.Data.CrossValidationSplit(fullInput).First(),
                mlContext.Data.CrossValidationSplit(fullInput, seed: 10).First(),
                mlContext.Data.CrossValidationSplit(fullInput, seed: 10, samplingKeyColumnName: "Workclass").First()
            };

            var splits = new List<IDataView>();
            foreach (var pair in trainTestSplits.Concat(firstFolds))
            {
                splits.Add(pair.TrainSet);
                splits.Add(pair.TestSet);
            }

            // Splitting a dataset shouldn't affect its schema
            var expectedSchema = fullInput.Schema;
            foreach (var split in splits)
            {
                Assert.Equal(expectedSchema.Count, split.Schema.Count);
                foreach (var expectedColumn in expectedSchema)
                {
                    Assert.Equal(expectedColumn.Name, split.Schema[expectedColumn.Index].Name);
                }
            }
        }
 
        /// <summary>
        /// Verifies that <c>TrainTestSplit</c> produces different test sets for different seeds, and that
        /// when "Workclass" is used as the sampling key column no Workclass value appears in both the
        /// train and the test set.
        /// </summary>
        [Fact]
        public void TestTrainTestSplit()
        {
            var mlContext = new MLContext(0);
            var dataPath = GetDataPath("adult.tiny.with-schema.txt");
            // Create the reader: define the data columns and where to find them in the text file.
            var input = mlContext.Data.LoadFromTextFile(dataPath, new[] {
                            new TextLoader.Column("Label", DataKind.Boolean, 0),
                            new TextLoader.Column("Workclass", DataKind.String, 1),
                            new TextLoader.Column("Education", DataKind.String,2),
                            new TextLoader.Column("Age", DataKind.Single,9)
            }, hasHeader: true);

            // Returns the content of the "Workclass" column of a data view as a list of strings.
            List<string> GetWorkclass(IDataView view) =>
                view.GetColumn<ReadOnlyMemory<char>>(view.Schema["Workclass"]).Select(x => x.ToString()).ToList();

            // Check that the seed influences the split: split the same dataset once with the
            // default seed and once with an explicit seed.
            var simpleSplit = mlContext.Data.TrainTestSplit(input);
            var splitWithSeed = mlContext.Data.TrainTestSplit(input, seed: 10);

            // Since the test fraction is 0.1, it's much faster to compare the test subsets of the splits.
            var simpleTestWorkClass = GetWorkclass(simpleSplit.TestSet);
            var simpleWithSeedTestWorkClass = GetWorkclass(splitWithSeed.TestSet);
            // Validate we get different test sets.
            Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass);

            // Now do the same with a sampling key column. Rows sharing a value in that column
            // should end up in the same subset (train or test), so split on "Workclass".
            var stratSplit = mlContext.Data.TrainTestSplit(input, samplingKeyColumnName: "Workclass");
            // Distinct "Workclass" values of the train and test subsets.
            var uniqueTrain = GetWorkclass(stratSplit.TrainSet).Distinct().ToList();
            var uniqueTest = GetWorkclass(stratSplit.TestSet).Distinct().ToList();
            // No Workclass value may appear on both sides of the split.
            Assert.Empty(uniqueTrain.Intersect(uniqueTest));

            // Same check with a different seed: the sampling key column must still keep
            // equal values together in one subset.
            var stratSeed = mlContext.Data.TrainTestSplit(input, samplingKeyColumnName: "Workclass", seed: 1000000);
            var uniqueSeedTrain = GetWorkclass(stratSeed.TrainSet).Distinct().ToList();
            var uniqueSeedTest = GetWorkclass(stratSeed.TestSet).Distinct().ToList();

            Assert.Empty(uniqueSeedTrain.Intersect(uniqueSeedTest));
            // Validate we got different test results on the same sampling key column with different seeds.
            Assert.NotEqual(uniqueTest, uniqueSeedTest);
        }
 
        // Row type for TestSplitsWithSamplingKeyColumn. Each *Strat property is used there
        // as a sampling key column of a different data type; Id identifies the row in assertions.
        private sealed class Input
        {
            public int Id { get; set; }
            public string TextStrat { get; set; }
            public float FloatStrat { get; set; }
            // The test always supplies four-element arrays, matching the attribute argument.
            [VectorType(4)]
            public float[] VectorStrat { get; set; }
            public DateTime DateTimeStrat { get; set; }
            public DateTimeOffset DateTimeOffsetStrat { get; set; }
            public TimeSpan TimeSpanStrat { get; set; }
        }
 
        /// <summary>
        /// Verifies <c>TrainTestSplit</c> and <c>CrossValidationSplit</c> when the sampling key column is
        /// text, float, vector, DateTime, DateTimeOffset, TimeSpan or key typed: specific row ids are
        /// expected in specific subsets, cross-validation test folds must be disjoint and non-empty,
        /// and the sampling key column must still be present in the resulting data views.
        /// </summary>
        [Fact]
        public void TestSplitsWithSamplingKeyColumn()
        {
            var mlContext = new MLContext(0);
            var input = mlContext.Data.LoadFromEnumerable(new[]
            {
                new Input() {
                    Id = 0, TextStrat = "a", FloatStrat = 3, VectorStrat = new float[]{ 2, 3, 4, 5 }, DateTimeStrat = new DateTime(2002, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(1, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 0)
                },
                new Input() {
                    Id = 1, TextStrat = "b", FloatStrat = 3, VectorStrat = new float[]{ 1, 2, 3, 4 }, DateTimeStrat = new DateTime(2020, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(2, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 10)
                },
                new Input() {
                    Id = 2, TextStrat = "c", FloatStrat = 4, VectorStrat = new float[]{ 3, 4, 5, 6 }, DateTimeStrat = new DateTime(2018, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(1, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 10)
                },
                new Input() {
                    Id = 3, TextStrat = "d", FloatStrat = 4, VectorStrat = new float[]{ 4, 5, 6, 7 }, DateTimeStrat = new DateTime(2016, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(2, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 0)
                },
                new Input() {
                    Id = 4, TextStrat = "a", FloatStrat = -493.28f, VectorStrat = new float[]{ 2, 3, 4, 5 }, DateTimeStrat = new DateTime(2016, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(3, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 20)
                },
                new Input() {
                    Id = 5, TextStrat = "b", FloatStrat = -493.28f, VectorStrat = new float[]{ 1, 2, 3, 4 }, DateTimeStrat = new DateTime(2018, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(4, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 30)
                },
                new Input() {
                    Id = 6, TextStrat = "c", FloatStrat = 6, VectorStrat = new float[]{ 3, 4, 5, 6 }, DateTimeStrat = new DateTime(2020, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(3, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 30)
                },
                new Input() {
                    Id = 7, TextStrat = "d", FloatStrat = 6, VectorStrat = new float[]{ 4, 5, 6, 7 }, DateTimeStrat = new DateTime(2002, 2, 23),
                    DateTimeOffsetStrat = new DateTimeOffset(2002, 2, 23, 3, 30, 0, new TimeSpan(4, 0, 0)), TimeSpanStrat = new TimeSpan(2, 0, 20)
                },
            });

            // Reads the Id column once into a list, so repeated assertions do not re-enumerate
            // the lazily evaluated column each time.
            List<int> GetIds(IDataView view) => view.GetColumn<int>(view.Schema[nameof(Input.Id)]).ToList();

            // TEST TRAINTESTSPLIT
            var split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TextStrat));
            var ids = GetIds(split.TestSet);
            Assert.Contains(1, ids);
            Assert.Contains(5, ids);

            split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.FloatStrat));
            ids = GetIds(split.TestSet);
            Assert.Contains(4, ids);
            Assert.Contains(5, ids);

            split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.VectorStrat));
            ids = GetIds(split.TestSet);
            Assert.Contains(0, ids);
            Assert.Contains(4, ids);

            split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeStrat));
            ids = GetIds(split.TestSet);
            Assert.Contains(5, ids);
            Assert.Contains(6, ids);

            // Note: this case inspects the train set, unlike the other cases which inspect the test set.
            split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.DateTimeOffsetStrat));
            ids = GetIds(split.TrainSet);
            Assert.Contains(4, ids);
            Assert.Contains(7, ids);

            split = mlContext.Data.TrainTestSplit(input, 0.5, nameof(Input.TimeSpanStrat));
            ids = GetIds(split.TestSet);
            Assert.Contains(1, ids);
            Assert.Contains(2, ids);

            var inputWithKey = mlContext.Transforms.Conversion.MapValueToKey("KeyStrat", "TextStrat").Fit(input).Transform(input);
            split = mlContext.Data.TrainTestSplit(inputWithKey, 0.5, "KeyStrat");
            ids = GetIds(split.TestSet);
            Assert.Contains(1, ids);
            Assert.Contains(5, ids);
            Assert.NotNull(split.TrainSet.Schema.GetColumnOrNull("KeyStrat")); // Check that the key column used as SamplingKeyColumn wasn't deleted by the split

            // TEST CROSSVALIDATIONSPLIT
            var colnames = new[] {
                nameof(Input.TextStrat),
                nameof(Input.FloatStrat),
                nameof(Input.VectorStrat),
                nameof(Input.DateTimeStrat),
                nameof(Input.DateTimeOffsetStrat),
                nameof(Input.TimeSpanStrat),
                "KeyStrat" };

            foreach (var colname in colnames)
            {
                var cvSplits = mlContext.Data.CrossValidationSplit(inputWithKey, numberOfFolds: 2, samplingKeyColumnName: colname);
                var idsTest1 = GetIds(cvSplits[0].TestSet);
                var idsTest2 = GetIds(cvSplits[1].TestSet);

                // The two test folds must not share any row.
                Assert.Empty(idsTest1.Intersect(idsTest2));
                Assert.True(idsTest1.Count > 0, $"CV Split 0 for Column {colname} was empty");
                Assert.True(idsTest2.Count > 0, $"CV Split 1 for Column {colname} was empty");

                // Check that using CV didn't remove the SamplingKeyColumn. This inspects the
                // cross-validation folds themselves, not the earlier train/test split.
                foreach (var fold in cvSplits)
                {
                    Assert.NotNull(fold.TrainSet.Schema.GetColumnOrNull(colname));
                    Assert.NotNull(fold.TestSet.Schema.GetColumnOrNull(colname));
                }
            }
        }
    }
}