File: Transformers\KeyToValueTests.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests.Transformers
{
    /// <summary>
    /// Tests for the key-to-value mapping estimator: an estimator workout with a
    /// baseline comparison, and a value -> key -> value round-trip check.
    /// </summary>
    public sealed class KeyToValueTests : TestDataPipeBase
    {
        public KeyToValueTests(ITestOutputHelper output) : base(output)
        {
        }

        /// <summary>
        /// Builds the value-to-key estimator shared by both tests: it produces
        /// column "A" from "ScalarString" and column "B" from "VectorString".
        /// </summary>
        private ValueToKeyMappingEstimator CreateTermEstimator()
        {
            var termColumns = new[]
            {
                new ValueToKeyMappingEstimator.ColumnOptions("A", "ScalarString"),
                new ValueToKeyMappingEstimator.ColumnOptions("B", "VectorString")
            };

            return new ValueToKeyMappingEstimator(Env, termColumns);
        }

        /// <summary>
        /// Runs the estimator workout for <see cref="KeyToValueMappingEstimator"/> against
        /// valid and invalid inputs, then saves the transformed data and compares it with
        /// the "KeyToValue/featurized.tsv" baseline.
        /// </summary>
        [Fact]
        public void KeyToValueWorkout()
        {
            string irisPath = GetDataPath("iris.txt");

            var loaderOptions = new TextLoader.Options
            {
                Columns = new TextLoader.Column[]
                {
                    new TextLoader.Column("ScalarString", DataKind.String, 1),
                    new TextLoader.Column("VectorString", DataKind.String, new TextLoader.Range[] { new TextLoader.Range(1, 4) }),
                    new TextLoader.Column("BareKey", DataKind.UInt32, new TextLoader.Range[] { new TextLoader.Range(0) }, new KeyCount(6))
                }
            };
            var loader = new TextLoader(Env, loaderOptions);
            var rawData = loader.Load(irisPath);

            // Convert the string columns into key columns "A" and "B".
            var keyedData = CreateTermEstimator().Fit(rawData).Transform(rawData);

            // Invalid inputs: "A" becomes a copy of "BareKey", and "B" becomes a copy of the
            // string column "VectorString".
            // NOTE(review): presumably these are rejected because the copied columns lack
            // key-value metadata / are not key-typed - confirm against the estimator.
            var invalidScalar = new ColumnCopyingTransformer(Env, ("A", "BareKey")).Transform(keyedData);
            var invalidVector = new ColumnCopyingTransformer(Env, ("B", "VectorString")).Transform(keyedData);

            var keyToValue = new KeyToValueMappingEstimator(Env, ("A_back", "A"), ("B_back", "B"));
            TestEstimatorCore(keyToValue, keyedData, invalidInput: invalidScalar);
            TestEstimatorCore(keyToValue, keyedData, invalidInput: invalidVector);

            var featurizedPath = GetOutputPath("KeyToValue", "featurized.tsv");
            using (var channel = Env.Start("save"))
            {
                var textSaver = new TextSaver(Env, new TextSaver.Arguments { Silent = true });
                var fittedTransformer = keyToValue.Fit(keyedData);
                IDataView featurized = fittedTransformer.Transform(keyedData);
                using (var outputStream = File.Create(featurizedPath))
                {
                    DataSaverUtils.SaveDataView(channel, textSaver, featurized, outputStream, keepHidden: true);
                }
            }

            CheckEquality("KeyToValue", "featurized.tsv");
            Done();
        }

        /// <summary>
        /// Maps string columns to keys and back again through the catalog API, then checks
        /// that the restored columns have the same schema and values as the source columns.
        /// </summary>
        [Fact]
        public void KeyToValue()
        {
            string trainPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
            var sourceColumns = new TextLoader.Column[]
            {
                new TextLoader.Column("ScalarString", DataKind.String, 0),
                new TextLoader.Column("VectorString", DataKind.String, 1, 4),
            };
            var sourceData = ML.Data.LoadFromTextFile(trainPath, sourceColumns);

            var keyedData = CreateTermEstimator().Fit(sourceData).Transform(sourceData);

            // Map the keys back, writing over the original "ScalarString" / "VectorString" names.
            var scalarBack = ML.Transforms.Conversion.MapKeyToValue("ScalarString", "A");
            var vectorBack = ML.Transforms.Conversion.MapKeyToValue("VectorString", "B");
            var backToValues = scalarBack.Append(vectorBack);

            TestEstimatorCore(backToValues, keyedData, invalidInput: sourceData);

            var restoredData = backToValues.Fit(keyedData).Transform(keyedData);

            // The value -> key -> value mapping should round-trip: compare only the two
            // string columns of the source data against those of the restored data.
            var expected = ML.Transforms.SelectColumns("ScalarString", "VectorString").Fit(sourceData).Transform(sourceData);
            var actual = ML.Transforms.SelectColumns("ScalarString", "VectorString").Fit(restoredData).Transform(restoredData);

            TestCommon.CheckSameSchemas(expected.Schema, actual.Schema);
            CheckSameValues(expected, actual);
            Done();
        }
    }
}