// File: Transformers\ConcatTests.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.RunTests;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests.Transformers
{
    /// <summary>
    /// Tests for column concatenation: the <c>Concatenate</c> estimator exposed on
    /// <c>ML.Transforms</c> and the <see cref="ColumnConcatenatingTransformer"/> used directly.
    /// </summary>
    public sealed class ConcatTests : TestDataPipeBase
    {
        public ConcatTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Requesting a concatenation with an output name but no input columns must throw,
        /// and the exception message must mention that the input columns were not specified.
        /// </summary>
        [Fact]
        public void TestConcatNoInputColumns()
        {
            // ThrowsAny mirrors the breadth of a `catch (Exception)`: any exception type is
            // accepted, only the message is pinned.
            var ex = Assert.ThrowsAny<Exception>(() => ML.Transforms.Concatenate("Features"));
            Assert.Contains("Input columns not specified", ex.Message);

            Done();
        }

        /// <summary>
        /// Concatenates scalar, fixed-size vector and variable-size vector columns through a
        /// chain of estimators, checks the resulting column types, and compares the saved
        /// output ("Concat1.tsv") with the baseline.
        /// </summary>
        [Fact]
        public void TestConcat()
        {
            string dataPath = GetDataPath("adult.tiny.with-schema.txt");

            var source = new MultiFileSource(dataPath);
            var loader = new TextLoader(ML, new TextLoader.Options
            {
                Columns = new[]{
                    new TextLoader.Column("float1", DataKind.Single, 9),
                    new TextLoader.Column("float4", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12) }),
                    new TextLoader.Column("float6", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12, 14) }),
                    new TextLoader.Column("vfloat", DataKind.Single, new[]{new TextLoader.Range(14, null) { AutoEnd = false, VariableEnd = true } })
                },
                Separator = "\t",
                HasHeader = true
            }, source); // Reuse the same source as the loader's data sample instead of building a second one.
            var data = loader.Load(source);

            var pipe = ML.Transforms.Concatenate("f1", "float1")
                .Append(ML.Transforms.Concatenate("f2", "float1", "float1"))
                .Append(ML.Transforms.Concatenate("f3", "float4", "float1"))
                .Append(ML.Transforms.Concatenate("f4", "float6", "vfloat", "float1"));

            data = ML.Data.TakeRows(data, 10);
            data = pipe.Fit(data).Transform(data);

            AssertSingleVectorColumn(data.Schema, "f1", 1);
            AssertSingleVectorColumn(data.Schema, "f2", 2);
            AssertSingleVectorColumn(data.Schema, "f3", 5);
            // "f4" includes the variable-length "vfloat" column; the expected Size is 0.
            // NOTE(review): presumably Size == 0 denotes an unknown/variable length - confirm.
            AssertSingleVectorColumn(data.Schema, "f4", 0);

            data = ML.Transforms.SelectColumns("f1", "f2", "f3", "f4").Fit(data).Transform(data);

            SaveAndCheckBaseline(data, "Concat1.tsv");
            Done();
        }

        /// <summary>
        /// Builds a <see cref="ColumnConcatenatingTransformer"/> directly from column options
        /// whose inputs are (name, alias) pairs, checks its <c>Columns</c> property and the
        /// resulting column types, and compares the saved output ("Concat2.tsv") with the baseline.
        /// </summary>
        [Fact]
        public void ConcatWithAliases()
        {
            string dataPath = GetDataPath("adult.tiny.with-schema.txt");

            var source = new MultiFileSource(dataPath);
            var loader = new TextLoader(ML, new TextLoader.Options
            {
                Columns = new[]{
                    new TextLoader.Column("float1", DataKind.Single, 9),
                    new TextLoader.Column("float4", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12) }),
                    new TextLoader.Column("vfloat", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12, null) { AutoEnd = false, VariableEnd = true } })
                },
                Separator = "\t",
                HasHeader = true
            }, source); // Reuse the same source as the loader's data sample instead of building a second one.
            var data = loader.Load(source);

            data = ML.Data.TakeRows(data, 10);

            var concater = new ColumnConcatenatingTransformer(ML,
                new ColumnConcatenatingTransformer.ColumnOptions("f2", new[] { ("float1", "FLOAT1"), ("float1", "FLOAT2") }),
                new ColumnConcatenatingTransformer.ColumnOptions("f3", new[] { ("float4", "FLOAT4"), ("float1", "FLOAT1") }));
            data = concater.Transform(data);

            // Test Columns property: it must report exactly the two configured outputs, in
            // order, each with its two input column names.
            var expectedColumns = new[]
            {
                (Output: "f2", First: "float1", Second: "float1"),
                (Output: "f3", First: "float4", Second: "float1"),
            };
            var seen = 0;
            // foreach (rather than manual MoveNext/Current) guarantees Current is only read
            // when an element exists and takes care of disposing the enumerator.
            foreach (var column in concater.Columns)
            {
                Assert.True(seen < expectedColumns.Length, "Columns reported more entries than were configured");
                var expected = expectedColumns[seen];
                Assert.Equal(expected.Output, column.outputColumnName);
                Assert.Equal(expected.First, column.inputColumnNames[0]);
                Assert.Equal(expected.Second, column.inputColumnNames[1]);
                seen++;
            }
            Assert.Equal(expectedColumns.Length, seen);

            AssertSingleVectorColumn(data.Schema, "f2", 2);
            AssertSingleVectorColumn(data.Schema, "f3", 5);

            data = ML.Transforms.SelectColumns("f2", "f3").Fit(data).Transform(data);

            SaveAndCheckBaseline(data, "Concat2.tsv");
            Done();
        }

        /// <summary>
        /// Asserts that <paramref name="schema"/> contains a column called <paramref name="name"/>
        /// whose type is a vector of <see cref="NumberDataViewType.Single"/> with the given size.
        /// </summary>
        /// <param name="schema">Schema to look the column up in.</param>
        /// <param name="name">Name of the column to check.</param>
        /// <param name="expectedSize">Expected value of the vector type's <c>Size</c>.</param>
        private static void AssertSingleVectorColumn(DataViewSchema schema, string name, int expectedSize)
        {
            Assert.True(schema.TryGetColumnIndex(name, out int cIdx), $"Could not find '{name}'");

            // Separate assertions so a failure reports which property was wrong.
            var vectorType = Assert.IsAssignableFrom<VectorDataViewType>(schema[cIdx].Type);
            Assert.Same(NumberDataViewType.Single, vectorType.ItemType);
            Assert.Equal(expectedSize, vectorType.Size);
        }

        /// <summary>
        /// Writes <paramref name="data"/> as dense text to <paramref name="fileName"/> under the
        /// "Transform/Concat" output sub-directory, then runs the base-class equality check
        /// on that file.
        /// </summary>
        /// <param name="data">Data view to save.</param>
        /// <param name="fileName">File name used for both the output and the equality check.</param>
        private void SaveAndCheckBaseline(IDataView data, string fileName)
        {
            var subdir = Path.Combine("Transform", "Concat");
            var outputPath = GetOutputPath(subdir, fileName);
            using (var ch = Env.Start("save"))
            {
                var saver = new TextSaver(ML, new TextSaver.Arguments { Silent = true, Dense = true });
                using (var fs = File.Create(outputPath))
                {
                    DataSaverUtils.SaveDataView(ch, saver, data, fs, keepHidden: false);
                }
            }

            CheckEquality(subdir, fileName);
        }
    }
}