// File: TermEstimatorTests.cs
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Model;
using Microsoft.ML.RunTests;
using Microsoft.ML.Tools;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests
{
    /// <summary>
    /// Tests for <see cref="ValueToKeyMappingEstimator"/> (the "Term" transform): fitting on
    /// different input types, save/load round-tripping, key-value metadata and command-line use.
    /// </summary>
    public class TermEstimatorTests : TestDataPipeBase
    {
        public TermEstimatorTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>Three integer columns used as the training input of the estimator.</summary>
        class TestClass
        {
            public int A;
            public int B;
            public int C;
        }

        /// <summary>Integer columns whose names (X, Y) do not match the columns the estimator is configured for.</summary>
        class TestClassXY
        {
            public int X;
            public int Y;
        }

        /// <summary>Same column names as <see cref="TestClass"/>, but typed as strings instead of integers.</summary>
        class TestClassDifferentTypes
        {
            public string A;
            public string B;
            public string C;
        }

        /// <summary>Input for the key-value metadata test; only <see cref="Term"/> is mapped to a key.</summary>
        class TestMetaClass
        {
            public int NotUsed;
            public string Term;
        }

        /// <summary>
        /// Maps float, double, int and text columns (scalar and vector) to keys, saves the first
        /// ten transformed rows as text and compares the output via <c>CheckEquality</c>.
        /// </summary>
        [Fact]
        public void TestDifferentTypes()
        {
            string dataPath = GetDataPath("adult.tiny.with-schema.txt");

            // Source column 9 is deliberately loaded several times under different data kinds
            // (Single, Double, Int32) so that each kind is exercised by the estimator.
            var loader = new TextLoader(ML, new TextLoader.Options
            {
                Columns = new[]{
                    new TextLoader.Column("float1", DataKind.Single, 9),
                    new TextLoader.Column("float4", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12) }),
                    new TextLoader.Column("double1", DataKind.Double, 9),
                    new TextLoader.Column("double4", DataKind.Double, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12) }),
                    new TextLoader.Column("int1", DataKind.Int32, 9),
                    new TextLoader.Column("text1", DataKind.String, 1),
                    new TextLoader.Column("text2", DataKind.String, new[]{new TextLoader.Range(1), new TextLoader.Range(2)}),
                },
                Separator = "\t",
                HasHeader = true
            }, new MultiFileSource(dataPath));

            var pipe = new ValueToKeyMappingEstimator(ML, new[]{
                    new ValueToKeyMappingEstimator.ColumnOptions("TermFloat1", "float1"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermFloat4", "float4"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermDouble1", "double1"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermDouble4", "double4"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermInt1", "int1"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermText1", "text1"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermText2", "text2")
                });
            var data = loader.Load(dataPath);
            data = ML.Data.TakeRows(data, 10);
            var outputPath = GetOutputPath("Term", "Term.tsv");
            using (var ch = Env.Start("save"))
            {
                var saver = new TextSaver(ML, new TextSaver.Arguments { Silent = true });
                using (var fs = File.Create(outputPath))
                {
                    DataSaverUtils.SaveDataView(ch, saver, pipe.Fit(data).Transform(data), fs, keepHidden: true);
                }
            }

            CheckEquality("Term", "Term.tsv");
            Done();
        }

        /// <summary>
        /// Runs the estimator through <c>TestEstimatorCore</c> with a valid data view, a view with
        /// non-matching column names, and a view with matching names but string-typed columns.
        /// </summary>
        [Fact]
        public void TestSimpleCase()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };

            var xydata = new[] { new TestClassXY() { X = 10, Y = 100 }, new TestClassXY() { X = -1, Y = -100 } };
            var stringData = new[] { new TestClassDifferentTypes { A = "1", B = "c", C = "b" } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            var pipe = new ValueToKeyMappingEstimator(Env, new[]{
                   new ValueToKeyMappingEstimator.ColumnOptions("TermA", "A"),
                   new ValueToKeyMappingEstimator.ColumnOptions("TermB", "B"),
                   new ValueToKeyMappingEstimator.ColumnOptions("TermC", "C")
                });
            var invalidData = ML.Data.LoadFromEnumerable(xydata);
            var validFitNotValidTransformData = ML.Data.LoadFromEnumerable(stringData);
            TestEstimatorCore(pipe, dataView, null, invalidData, validFitNotValidTransformData);

            // Called here as in TestDifferentTypes.
            // NOTE(review): presumably Done() is what surfaces failures recorded by the
            // base-class helpers - confirm against the test base class.
            Done();
        }

        /// <summary>
        /// Fits the estimator, saves the transformed data's model to a memory stream, reloads the
        /// transforms on top of the original data view and validates the produced key columns.
        /// </summary>
        [Fact]
        public void TestOldSavingAndLoading()
        {
            var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            var est = new ValueToKeyMappingEstimator(Env, new[]{
                    new ValueToKeyMappingEstimator.ColumnOptions("TermA", "A"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermB", "B"),
                    new ValueToKeyMappingEstimator.ColumnOptions("TermC", "C")
                });
            var transformer = est.Fit(dataView);
            var result = transformer.Transform(dataView);
            var resultRoles = new RoleMappedData(result);
            using (var ms = new MemoryStream())
            {
                // Dispose the channel once saving is finished, the same way TestDifferentTypes
                // scopes its "save" channel, instead of leaving it undisposed.
                using (var ch = Env.Start("saving"))
                {
                    TrainUtils.SaveModel(Env, ch, ms, null, resultRoles);
                }

                // Rewind so the loader reads the model from the beginning of the stream.
                ms.Position = 0;
                var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
                ValidateTermTransformer(loadedView);
            }
        }

        /// <summary>
        /// Verifies that the key column produced from a string column exposes a non-empty set of
        /// key values in its schema annotations.
        /// </summary>
        [Fact]
        public void TestMetadataCopy()
        {
            var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } };
            var dataView = ML.Data.LoadFromEnumerable(data);
            var termEst = new ValueToKeyMappingEstimator(Env, new[] {
                    new ValueToKeyMappingEstimator.ColumnOptions("T", "Term") });

            var termTransformer = termEst.Fit(dataView);
            var result = termTransformer.Transform(dataView);

            // Fail loudly if the output column is missing rather than continuing with a default index.
            Assert.True(result.Schema.TryGetColumnIndex("T", out int termIndex));

            var names1 = default(VBuffer<ReadOnlyMemory<char>>);
            result.Schema[termIndex].GetKeyValues(ref names1);
            Assert.True(names1.GetValues().Length > 0);
        }

        /// <summary>
        /// Runs a <c>showschema</c> command with a Term transform through <c>Maml.Main</c> and
        /// expects a zero exit code.
        /// </summary>
        [Fact]
        public void TestCommandLine()
        {
            // NOTE(review): the command references f:\2.txt; presumably showschema does not need
            // the file to exist - confirm.
            Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Term{col=B:A} in=f:\2.txt" }));
        }

        /// <summary>
        /// Walks every row of <paramref name="result"/> and checks that the TermA, TermB and TermC
        /// key columns all hold the 1-based position of the row.
        /// </summary>
        /// <param name="result">Data view expected to contain the TermA, TermB and TermC columns readable as <see cref="uint"/>.</param>
        private void ValidateTermTransformer(IDataView result)
        {
            using (var cursor = result.GetRowCursorForAllColumns())
            {
                uint avalue = 0;
                uint bvalue = 0;
                uint cvalue = 0;

                var aGetter = cursor.GetGetter<uint>(result.Schema["TermA"]);
                var bGetter = cursor.GetGetter<uint>(result.Schema["TermB"]);
                var cGetter = cursor.GetGetter<uint>(result.Schema["TermC"]);

                // Expected key for the current row; starts at 1 and grows by one per row.
                uint i = 1;
                while (cursor.MoveNext())
                {
                    aGetter(ref avalue);
                    bGetter(ref bvalue);
                    cGetter(ref cvalue);
                    Assert.Equal(i, avalue);
                    Assert.Equal(i, bvalue);
                    Assert.Equal(i, cvalue);
                    i++;
                }

                // Without this, an empty view would skip every assertion above and pass vacuously.
                Assert.True(i > 1, "Expected the data view to contain at least one row.");
            }
        }
    }
}