File: Transformers\CopyColumnEstimatorTests.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFramework;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Tools;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests
{
    /// <summary>
    /// Tests for <see cref="ColumnCopyingEstimator"/> and the transformer it produces:
    /// basic copying, schema validation failures, model save/load round trips,
    /// annotation (key value) propagation and command-line invocation.
    /// </summary>
    public class CopyColumnEstimatorTests : BaseTestClass
    {
        public CopyColumnEstimatorTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>Three-integer-column input row used by most tests.</summary>
        class TestClass
        {
            public int A;
            public int B;
            public int C;
        }

        /// <summary>Input row whose columns do not match <see cref="TestClass"/>; used to trigger schema mismatches.</summary>
        class TestClassXY
        {
            public int X;
            public int Y;
        }

        /// <summary>Input row with a string column that is mapped to a key column in <see cref="TestMetadataCopy"/>.</summary>
        class TestMetaClass
        {
            public int NotUsed;
            public string Term;
        }

        /// <summary>
        /// Fits the estimator on data containing the source columns and checks that
        /// D, E and F carry the same values as A, B and A respectively.
        /// </summary>
        [Fact]
        public void TestWorking()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") });
            var transformer = est.Fit(dataView);
            var result = transformer.Transform(dataView);
            ValidateCopyColumnTransformer(result);
        }

        /// <summary>
        /// Fitting must fail when a requested source column ("D") is absent from the input schema.
        /// </summary>
        [Fact]
        public void TestBadOriginalSchema()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("E", "B") });

            // The previous try { ...; Assert.False(true); } catch { } form swallowed the
            // assertion failure as well, so the test could never fail. ThrowsAny fails
            // the test when no exception is raised.
            Assert.ThrowsAny<Exception>(() => est.Fit(dataView));
        }

        /// <summary>
        /// Transforming must fail when the transformer is applied to data whose schema
        /// lacks the source columns it was fitted on.
        /// </summary>
        [Fact]
        public void TestBadTransformSchema()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var xydata = new[] { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var xyDataView = env.Data.LoadFromEnumerable(xydata);
            var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") });
            var transformer = est.Fit(dataView);

            // See TestBadOriginalSchema: a bare catch around Assert.False(true) hides the failure.
            Assert.ThrowsAny<Exception>(() => transformer.Transform(xyDataView));
        }

        /// <summary>
        /// Saves the fitted transformer through the model API, loads it back from the same
        /// stream and checks the loaded transformer still copies the columns.
        /// </summary>
        [Fact]
        public void TestSavingAndLoading()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") });
            var transformer = est.Fit(dataView);
            using (var ms = new MemoryStream())
            {
                env.Model.Save(transformer, null, ms);
                ms.Position = 0;
                // The schema returned alongside the model is not needed here.
                var loadedTransformer = env.Model.Load(ms, out _);
                var result = loadedTransformer.Transform(dataView);
                ValidateCopyColumnTransformer(result);
            }
        }

        /// <summary>
        /// Saves the transformed view with <see cref="TrainUtils.SaveModel"/>, reloads the
        /// transforms with <see cref="ModelFileUtils.LoadTransforms"/> and validates the result.
        /// </summary>
        [Fact]
        public void TestOldSavingAndLoading()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") });
            var transformer = est.Fit(dataView);
            var result = transformer.Transform(dataView);
            var resultRoles = new RoleMappedData(result);
            using (var ms = new MemoryStream())
            {
                // Dispose the channel once saving is done instead of leaking it.
                using (var ch = ((IHostEnvironment)env).Start("saving"))
                {
                    TrainUtils.SaveModel(env, ch, ms, null, resultRoles);
                }

                ms.Position = 0;
                var loadedView = ModelFileUtils.LoadTransforms(env, dataView, ms);
                ValidateCopyColumnTransformer(loadedView);
            }
        }

        /// <summary>
        /// Copies a key-typed column ("T" to "T1") and checks that the key values exposed
        /// by the copy equal those of the original column.
        /// </summary>
        [Fact]
        public void TestMetadataCopy()
        {
            var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } };
            var env = new MLContext(1);
            var dataView = env.Data.LoadFromEnumerable(data);
            var term = ValueToKeyMappingTransformer.Create(env, new ValueToKeyMappingTransformer.Options()
            {
                Columns = new[] { new ValueToKeyMappingTransformer.Column() { Source = "Term", Name = "T" } }
            }, dataView);
            var est = new ColumnCopyingEstimator(env, "T1", "T");
            var transformer = est.Fit(term);
            var result = transformer.Transform(term);

            // Fail with a clear message if either column is missing rather than
            // silently continuing with a default index.
            Assert.True(result.Schema.TryGetColumnIndex("T", out int termIndex));
            Assert.True(result.Schema.TryGetColumnIndex("T1", out int copyIndex));

            var names1 = default(VBuffer<ReadOnlyMemory<char>>);
            var names2 = default(VBuffer<ReadOnlyMemory<char>>);
            var type1 = result.Schema[termIndex].Type;
            var itemType1 = (type1 as VectorDataViewType)?.ItemType ?? type1;
            var key = itemType1 as KeyDataViewType;
            Assert.NotNull(key);
            // Guard the narrowing cast below.
            Assert.InRange<ulong>(key.Count, 0, int.MaxValue);
            int size = (int)key.Count;
            result.Schema[termIndex].GetKeyValues(ref names1);
            result.Schema[copyIndex].GetKeyValues(ref names2);
            Assert.True(TestCommon.CompareVec(in names1, in names2, size, (a, b) => a.Span.SequenceEqual(b.Span)));
        }

        /// <summary>
        /// Runs the copy transform through the Maml command line and expects exit code 0.
        /// </summary>
        [Fact]
        public void TestCommandLine()
        {
            Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=copy{col=B:A} in=f:\1.txt" }));
        }

        /// <summary>
        /// Iterates over <paramref name="result"/> and asserts, row by row, that column D
        /// equals A, E equals B and F equals A.
        /// </summary>
        /// <param name="result">
        /// Data view expected to contain <see cref="int"/> columns named A, B, D, E and F.
        /// </param>
        private void ValidateCopyColumnTransformer(IDataView result)
        {
            using (var cursor = result.GetRowCursorForAllColumns())
            {
                int avalue = 0;
                int bvalue = 0;
                int dvalue = 0;
                int evalue = 0;
                int fvalue = 0;

                // Resolve columns by name so the check does not depend on column order.
                var aGetter = cursor.GetGetter<int>(cursor.Schema["A"]);
                var bGetter = cursor.GetGetter<int>(cursor.Schema["B"]);
                var dGetter = cursor.GetGetter<int>(cursor.Schema["D"]);
                var eGetter = cursor.GetGetter<int>(cursor.Schema["E"]);
                var fGetter = cursor.GetGetter<int>(cursor.Schema["F"]);

                int rowCount = 0;
                while (cursor.MoveNext())
                {
                    aGetter(ref avalue);
                    bGetter(ref bvalue);
                    dGetter(ref dvalue);
                    eGetter(ref evalue);
                    fGetter(ref fvalue);
                    Assert.Equal(avalue, dvalue);
                    Assert.Equal(bvalue, evalue);
                    Assert.Equal(avalue, fvalue);
                    rowCount++;
                }

                // An empty cursor would otherwise make the validation pass vacuously.
                Assert.True(rowCount > 0);
            }
        }
    }
}