// File: Transformers\SelectColumnsTests.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Model;
using Microsoft.ML.RunTests;
using Microsoft.ML.Tools;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests.Transformers
{
    public class SelectColumnsTransformsTests : TestDataPipeBase
    {
        // Row type with three integer fields. Most tests in this class load an array of
        // these via ML.Data.LoadFromEnumerable and then look columns up by the names
        // "A", "B" and "C".
        class TestClass
        {
            public int A;
            public int B;
            public int C;
        }
 
        // Row type with fields D and E only. TestSelectWorkout loads it as the
        // "invalid" input for estimators that select columns A and B.
        class TestClass2
        {
            public int D;
            public int E;
        }
        // Row type used by the back-compat tests: string Label and Features fields
        // followed by the integer fields A, B and C.
        class TestClass3
        {
            public string Label;
            public string Features;
            public int A;
            public int B;
            public int C;
        };
 
        // Forwards the xUnit output helper to the TestDataPipeBase constructor.
        public SelectColumnsTransformsTests(ITestOutputHelper output)
            : base(output)
        {
        }
 
        // Keeping "A" and "C" must remove "B" and expose A at index 0 and C at index 1.
        [Fact]
        public void TestSelectKeep()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);

            var estimator = ColumnSelectingEstimator.KeepColumns(Env, "A", "C");
            var model = estimator.Fit(input);
            var outputSchema = model.Transform(input).Schema;

            // A is kept and comes first.
            Assert.True(outputSchema.TryGetColumnIndex("A", out int indexOfA));
            Assert.Equal(0, indexOfA);

            // B was not listed, so it must be gone.
            Assert.False(outputSchema.TryGetColumnIndex("B", out _));

            // C is kept and comes second.
            Assert.True(outputSchema.TryGetColumnIndex("C", out int indexOfC));
            Assert.Equal(1, indexOfC);
        }
 
        // Keeping "C" then "A" must emit the columns in the requested order: C A.
        [Fact]
        public void TestSelectKeepWithOrder()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);

            // Expected output will be CA
            var estimator = ColumnSelectingEstimator.KeepColumns(Env, "C", "A");
            var model = estimator.Fit(input);
            var outputSchema = model.Transform(input).Schema;

            Assert.True(outputSchema.TryGetColumnIndex("A", out int indexOfA));
            Assert.Equal(1, indexOfA);

            Assert.False(outputSchema.TryGetColumnIndex("B", out _));

            Assert.True(outputSchema.TryGetColumnIndex("C", out int indexOfC));
            Assert.Equal(0, indexOfC);
        }
 
        // Dropping "A" and "C" must leave "B" as the only named column, at index 0.
        [Fact]
        public void TestSelectDrop()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);

            var estimator = ColumnSelectingEstimator.DropColumns(Env, "A", "C");
            var model = estimator.Fit(input);
            var outputSchema = model.Transform(input).Schema;

            Assert.False(outputSchema.TryGetColumnIndex("A", out _));

            Assert.True(outputSchema.TryGetColumnIndex("B", out int indexOfB));
            Assert.Equal(0, indexOfB);

            Assert.False(outputSchema.TryGetColumnIndex("C", out _));
        }
 
        // Runs the shared estimator workout (TestEstimatorCore) on SelectColumns, once
        // with the default overload and once with the boolean argument set to true.
        [Fact]
        public void TestSelectWorkout()
        {
            var validRows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var invalidRows = new[] { new TestClass2 { D = 3, E = 5 } };
            var validView = ML.Data.LoadFromEnumerable(validRows);
            var invalidView = ML.Data.LoadFromEnumerable(invalidRows);

            // Workout on keep columns
            var keepEstimator = ML.Transforms.SelectColumns(new[] { "A", "B" });
            TestEstimatorCore(keepEstimator, validFitInput: validView, invalidInput: invalidView);

            // Workout on select columns with hidden: true
            var keepHiddenEstimator = ML.Transforms.SelectColumns(new[] { "A", "B" }, true);
            TestEstimatorCore(keepHiddenEstimator, validFitInput: validView, invalidInput: invalidView);
        }
 
        // Fitting a keep-estimator that names columns absent from the input ("D", "G")
        // is expected to throw ArgumentOutOfRangeException.
        [Fact]
        public void TestSelectColumnsWithMissing()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);

            var estimator = ColumnSelectingEstimator.KeepColumns(Env, "D", "G");

            Assert.Throws<ArgumentOutOfRangeException>(() => estimator.Fit(input));
        }
 
        // Copies A and B onto themselves (creating duplicate column names), then keeps
        // "C" and "A" and verifies that only two columns remain, in the requested order.
        [Fact]
        public void TestSelectColumnsWithSameName()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(Env, new[] { ("A", "A"), ("B", "B") });
            var chain = est.Append(ColumnSelectingEstimator.KeepColumns(Env, "C", "A"));
            var transformer = chain.Fit(dataView);
            var result = transformer.Transform(dataView);

            // After the copy step A and B each appear twice alongside C. We chose to
            // keep C and A (in that order), so the result is CA.
            Assert.Equal(2, result.Schema.Count);
            var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx);
            var foundColumnB = result.Schema.TryGetColumnIndex("B", out _);
            var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx);
            Assert.True(foundColumnA);
            Assert.Equal(1, aIdx);
            Assert.False(foundColumnB);
            Assert.True(foundColumnC);
            Assert.Equal(0, cIdx);
        }
 
        // Copies A and B onto themselves, then selects "B" and "A" with the boolean
        // argument set to true and verifies that four columns survive.
        [Fact]
        public void TestSelectColumnsWithKeepHidden()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(Env, new[] { ("A", "A"), ("B", "B") });
            var chain = est.Append(ML.Transforms.SelectColumns(new[] { "B", "A" }, true));
            var transformer = chain.Fit(dataView);
            var result = transformer.Transform(dataView);

            // After the copy step A and B each appear twice alongside C. We chose to
            // keep B and A (in that order) and keep hidden columns is true, therefore
            // the output should be BBAA: B resolves to index 1 and A to index 3.
            Assert.Equal(4, result.Schema.Count);
            var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx);
            var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx);
            var foundColumnC = result.Schema.TryGetColumnIndex("C", out _);
            Assert.True(foundColumnA);
            Assert.Equal(3, aIdx);
            Assert.True(foundColumnB);
            Assert.Equal(1, bIdx);
            Assert.False(foundColumnC);
        }
 
        // Round-trips a fitted keep-transformer through an in-memory stream and checks
        // that the reloaded transformer still outputs exactly columns A and B.
        [Fact]
        public void TestSelectSavingAndLoading()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);
            var model = ColumnSelectingEstimator.KeepColumns(Env, "A", "B").Fit(input);

            using (var buffer = new MemoryStream())
            {
                ML.Model.Save(model, null, buffer);

                // Rewind so that Load reads from the start of what was just written.
                buffer.Seek(0, SeekOrigin.Begin);
                var reloaded = ML.Model.Load(buffer, out _);

                var outputSchema = reloaded.Transform(input).Schema;
                Assert.Equal(2, outputSchema.Count);
                Assert.Equal("A", outputSchema[0].Name);
                Assert.Equal("B", outputSchema[1].Name);
            }
        }
 
        // Round-trips a copy + select chain (boolean argument false) through an
        // in-memory stream and checks that the reloaded chain outputs only A and B.
        [Fact]
        public void TestSelectSavingAndLoadingWithNoKeepHidden()
        {
            var rows = new[]
            {
                new TestClass { A = 1, B = 2, C = 3 },
                new TestClass { A = 4, B = 5, C = 6 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);

            var copyStep = new ColumnCopyingEstimator(Env, new[] { ("A", "A"), ("B", "B") });
            var pipeline = copyStep.Append(ML.Transforms.SelectColumns(new[] { "A", "B" }, false));
            var model = pipeline.Fit(input);

            using (var buffer = new MemoryStream())
            {
                ML.Model.Save(model, null, buffer);

                // Rewind so that Load reads from the start of what was just written.
                buffer.Seek(0, SeekOrigin.Begin);
                var reloaded = ML.Model.Load(buffer, out _);

                var outputSchema = reloaded.Transform(input).Schema;
                Assert.Equal(2, outputSchema.Count);
                Assert.Equal("A", outputSchema[0].Name);
                Assert.Equal("B", outputSchema[1].Name);
            }
        }
 
        // Loads the saved "drop" model from the back-compat test data and checks that
        // applying it removes column A while the other columns keep their order.
        [Fact]
        public void TestSelectBackCompatDropColumns()
        {
            // Model generated with: xf=drop{col=A}
            // Expected output: Label Features B C
            var data = new[] { new TestClass3() { Label = "foo", Features = "bar", A = 1, B = 2, C = 3, } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            string dropModelPath = GetDataPath("backcompat/drop-model.zip");
            using (FileStream fs = File.OpenRead(dropModelPath))
            {
                var result = ModelFileUtils.LoadTransforms(Env, dataView, fs);
                var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx);
                var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx);
                var foundColumnA = result.Schema.TryGetColumnIndex("A", out _);
                var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx);
                var foundColumnC = result.Schema.TryGetColumnIndex("C", out int cIdx);
                Assert.True(foundColumnLabel);
                Assert.Equal(0, labelIdx);
                Assert.True(foundColumnFeature);
                Assert.Equal(1, featureIdx);
                Assert.False(foundColumnA);
                Assert.True(foundColumnB);
                Assert.Equal(2, bIdx);
                Assert.True(foundColumnC);
                Assert.Equal(3, cIdx);
            }
        }
 
        // Loads the saved "keep" model from the back-compat test data and checks that
        // applying it yields Label, Features, A and B (in that order) without C.
        [Fact]
        public void TestSelectBackCompatKeepColumns()
        {
            // Model generated with: xf=keep{col=Label col=Features col=A col=B}
            // Expected output: Label Features A B
            var data = new[] { new TestClass3() { Label = "foo", Features = "bar", A = 1, B = 2, C = 3, } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            string keepModelPath = GetDataPath("backcompat/keep-model.zip");
            using (FileStream fs = File.OpenRead(keepModelPath))
            {
                var result = ModelFileUtils.LoadTransforms(Env, dataView, fs);
                var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx);
                var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx);
                var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx);
                var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx);
                var foundColumnC = result.Schema.TryGetColumnIndex("C", out _);
                Assert.True(foundColumnLabel);
                Assert.Equal(0, labelIdx);
                Assert.True(foundColumnFeature);
                Assert.Equal(1, featureIdx);
                Assert.True(foundColumnA);
                Assert.Equal(2, aIdx);
                Assert.True(foundColumnB);
                Assert.Equal(3, bIdx);
                Assert.False(foundColumnC);
            }
        }
 
        // Loads the saved "choose" model from the back-compat test data and checks that
        // applying it yields Label, Features, A and B (in that order) without C.
        [Fact]
        public void TestSelectBackCompatChooseColumns()
        {
            // Model generated with: xf=choose{col=Label col=Features col=A col=B}
            // Output expected is Label Features A B
            var data = new[] { new TestClass3() { Label = "foo", Features = "bar", A = 1, B = 2, C = 3, } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            string chooseModelPath = GetDataPath("backcompat/choose-model.zip");
            using (FileStream fs = File.OpenRead(chooseModelPath))
            {
                var result = ModelFileUtils.LoadTransforms(Env, dataView, fs);
                var foundColumnFeature = result.Schema.TryGetColumnIndex("Features", out int featureIdx);
                var foundColumnLabel = result.Schema.TryGetColumnIndex("Label", out int labelIdx);
                var foundColumnA = result.Schema.TryGetColumnIndex("A", out int aIdx);
                var foundColumnB = result.Schema.TryGetColumnIndex("B", out int bIdx);
                var foundColumnC = result.Schema.TryGetColumnIndex("C", out _);
                Assert.True(foundColumnLabel);
                Assert.Equal(0, labelIdx);
                Assert.True(foundColumnFeature);
                Assert.Equal(1, featureIdx);
                Assert.True(foundColumnA);
                Assert.Equal(2, aIdx);
                Assert.True(foundColumnB);
                Assert.Equal(3, bIdx);
                Assert.False(foundColumnC);
            }
        }
 
        // Loads the saved copy + choose(hidden=keep) model from the back-compat test
        // data and checks the six-column output layout.
        [Fact]
        public void TestSelectBackCompatChooseColumnsWithKeep()
        {
            // Model generated with: xf=copy{col=A:A col=B:B} xf=choose{col=Label col=Features col=A col=B hidden=keep}
            // Output expected is Label Features A A B B
            var rows = new[]
            {
                new TestClass3 { Label = "foo", Features = "bar", A = 1, B = 2, C = 3 },
            };
            var input = ML.Data.LoadFromEnumerable(rows);
            string modelPath = GetDataPath("backcompat/choose-keep-model.zip");

            using (FileStream modelStream = File.OpenRead(modelPath))
            {
                var outputSchema = ModelFileUtils.LoadTransforms(Env, input, modelStream).Schema;
                Assert.Equal(6, outputSchema.Count);

                Assert.True(outputSchema.TryGetColumnIndex("Label", out int labelIndex));
                Assert.Equal(0, labelIndex);

                Assert.True(outputSchema.TryGetColumnIndex("Features", out int featuresIndex));
                Assert.Equal(1, featuresIndex);

                // With duplicates present, the name lookup resolves to the later copy.
                Assert.True(outputSchema.TryGetColumnIndex("A", out int indexOfA));
                Assert.Equal(3, indexOfA);

                Assert.True(outputSchema.TryGetColumnIndex("B", out int indexOfB));
                Assert.Equal(5, indexOfB);

                Assert.False(outputSchema.TryGetColumnIndex("C", out _));
            }
        }
 
        // Runs showschema through Maml.Main with a select transform that lists A and B
        // as keepcol, and expects exit code 0.
        [Fact]
        public void TestCommandLineWithKeep()
        {
            string commandLine = @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B} in=f:\1.txt";

            int exitCode = Maml.Main(new[] { commandLine });

            Assert.Equal(0, exitCode);
        }
 
        // Runs showschema through Maml.Main with a select transform that lists A and B
        // as dropcol, and expects exit code 0.
        [Fact]
        public void TestCommandLineWithDrop()
        {
            string commandLine = @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{dropcol=A dropcol=B} in=f:\1.txt";

            int exitCode = Maml.Main(new[] { commandLine });

            Assert.Equal(0, exitCode);
        }
 
        // Runs showschema through Maml.Main with a select transform that keeps A and B
        // and passes hidden=-, and expects exit code 0.
        [Fact]
        public void TestCommandLineKeepWithoutHidden()
        {
            string commandLine = @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B hidden=-} in=f:\1.txt";

            int exitCode = Maml.Main(new[] { commandLine });

            Assert.Equal(0, exitCode);
        }
 
        // Runs showschema through Maml.Main with a select transform that keeps A and B
        // and passes ignore=-, and expects exit code 0.
        [Fact]
        public void TestCommandLineKeepWithIgnoreMismatch()
        {
            string commandLine = @"showschema loader=Text{col=A:R4:0 col=B:R4:1 col=C:R4:2} xf=select{keepcol=A keepcol=B ignore=-} in=f:\1.txt";

            int exitCode = Maml.Main(new[] { commandLine });

            Assert.Equal(0, exitCode);
        }
    }
}