// File: Transformers\CustomMappingTests.cs
// Web Access
// Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Transforms;
using Xunit;
using Xunit.Abstractions;
 
namespace Microsoft.ML.Tests.Transformers
{
    /// <summary>
    /// Tests for stateless and stateful custom mapping estimators, and for the custom-predicate
    /// row filters exposed on <c>ML.Data</c>.
    /// </summary>
    public sealed class CustomMappingTests : TestDataPipeBase
    {
        public CustomMappingTests(ITestOutputHelper helper) : base(helper)
        {
        }

        /// <summary>Input row consumed by <see cref="MyLambda"/>: one scalar float and one float vector.</summary>
        public class MyInput
        {
            public float Float1 { get; set; }
            public float[] Float4 { get; set; }
        }

        /// <summary>Output row produced by <see cref="MyLambda"/>.</summary>
        public class MyOutput
        {
            public string Together { get; set; }
        }

        /// <summary>
        /// Stateless mapping factory exposed under the contract name "MyLambda". The same name is passed
        /// to the estimator built in <c>TestCustomTransformer</c>.
        /// </summary>
        [CustomMappingFactoryAttribute("MyLambda")]
        public class MyLambda : CustomMappingFactory<MyInput, MyOutput>
        {
            /// <summary>
            /// Writes <c>"{Float1} + {v0}, {v1}, ..."</c> into <see cref="MyOutput.Together"/>.
            /// </summary>
            public static void MyAction(MyInput input, MyOutput output)
            {
                output.Together = $"{input.Float1} + {string.Join(", ", input.Float4)}";
            }

            public override Action<MyInput, MyOutput> GetMapping()
            {
                return MyAction;
            }
        }

        /// <summary>
        /// Runs the stateless custom mapping through <c>TestEstimatorCore</c>. Then checks that every output
        /// row holds the string that <see cref="MyLambda.MyAction"/> builds from the matching input row.
        /// </summary>
        /// <param name="registerAssembly">
        /// When true, the test assembly is also registered explicitly with the component catalog,
        /// which was required before 1.5-preview3.
        /// </param>
        [Theory]
        [InlineData(true)]
        [InlineData(false)]
        public void TestCustomTransformer(bool registerAssembly)
        {
            string dataPath = GetDataPath("adult.tiny.with-schema.txt");
            var source = new MultiFileSource(dataPath);
            var loader = ML.Data.CreateTextLoader(new[] {
                    new TextLoader.Column("Float1", DataKind.Single, 9),
                    new TextLoader.Column("Float4", DataKind.Single, new[]{new TextLoader.Range(9), new TextLoader.Range(10), new TextLoader.Range(11), new TextLoader.Range(12) })
            }, hasHeader: true);

            var data = loader.Load(source);

            // The custom transformer is instantiated on a separate, temporary environment. This ensures that
            // saving and loading the model do not require the same environment.
            var tempoEnv = new MLContext(1);
            var customEst = new CustomMappingEstimator<MyInput, MyOutput>(tempoEnv, MyLambda.MyAction, "MyLambda");

            // Before 1.5-preview3 it was required to register the assembly. Now the assembly information is
            // saved in the model automatically, and the assembly is registered when loading. This case checks
            // that the custom transformer still works when the assembly is registered explicitly anyway.
            if (registerAssembly)
                ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly);

            TestEstimatorCore(customEst, data);
            var transformedData = customEst.Fit(data).Transform(data);

            var inputs = ML.Data.CreateEnumerable<MyInput>(transformedData, true);
            var outputs = ML.Data.CreateEnumerable<MyOutput>(transformedData, true);

            Assert.True(inputs.Zip(outputs, (x, y) => y.Together == $"{x.Float1} + {string.Join(", ", x.Float4)}").All(x => x));

            Done();
        }

        /// <summary>
        /// Checks that <c>GetOutputSchema</c> accepts an input schema matching <see cref="MyInput"/>.
        /// Also checks that it throws when a required column has the wrong type or is missing.
        /// </summary>
        [Fact]
        public void TestSchemaPropagation()
        {
            string dataPath = GetDataPath("adult.test");
            var source = new MultiFileSource(dataPath);
            var loader = ML.Data.CreateTextLoader(new[] {
                    new TextLoader.Column("Float1", DataKind.Single, 0),
                    new TextLoader.Column("Float4", DataKind.Single, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }),
                    new TextLoader.Column("Text1", DataKind.String, 0)
            }, separatorChar: ',', hasHeader: true);

            var data = loader.Load(source);

            Action<MyInput, MyOutput> mapping = (input, output) => output.Together = input.Float1.ToString();
            var est = ML.Transforms.CustomMapping(mapping, null);

            // Make sure schema propagation works for valid data.
            est.GetOutputSchema(SchemaShape.Create(data.Schema));

            // Wrong type: overwrite "Float1" with a copy of the text column "Text1".
            // CopyColumns takes (outputColumnName, inputColumnName), so the output column comes first.
            // Assert.ThrowsAny is used instead of try/catch(Exception): the catch-all also swallowed the
            // xUnit failure, so the negative check could never fail.
            var badData1 = ML.Transforms.CopyColumns("Float1", "Text1").Fit(data).Transform(data);
            Assert.ThrowsAny<Exception>(() => est.GetOutputSchema(SchemaShape.Create(badData1.Schema)));

            // Missing column: keep only "Float1", so "Float4" required by MyInput is absent.
            var badData2 = ML.Transforms.SelectColumns(new[] { "Float1" }).Fit(data).Transform(data);
            Assert.ThrowsAny<Exception>(() => est.GetOutputSchema(SchemaShape.Create(badData2.Schema)));

            Done();
        }

        /// <summary>Input row for <see cref="MyStatefulLambda"/>.</summary>
        public class MyStatefulInput
        {
            public float Value { get; set; }
        }

        /// <summary>State carried between rows by <see cref="MyStatefulLambda"/>: the values seen so far.</summary>
        public class MyState
        {
            public HashSet<float> SeenValues;
        }

        /// <summary>Output row for <see cref="MyStatefulLambda"/>.</summary>
        public class MyStatefulOutput
        {
            public bool FirstAppearance { get; set; }
        }

        /// <summary>
        /// Stateful mapping factory exposed under the contract name "MyStatefulLambda". It flags each row
        /// whose <see cref="MyStatefulInput.Value"/> has not been seen earlier in the same state.
        /// </summary>
        [CustomMappingFactoryAttribute(nameof(MyStatefulLambda))]
        public class MyStatefulLambda : StatefulCustomMappingFactory<MyStatefulInput, MyStatefulOutput, MyState>
        {
            public override Action<MyStatefulInput, MyStatefulOutput, MyState> GetMapping()
            {
                return MyStatefulAction;
            }

            public override Action<MyState> GetStateInitAction()
            {
                return MyStateInit;
            }

            /// <summary>
            /// Sets <see cref="MyStatefulOutput.FirstAppearance"/> to true the first time a value is seen.
            /// Then records the value in the state.
            /// </summary>
            public static void MyStatefulAction(MyStatefulInput input, MyStatefulOutput output, MyState state)
            {
                output.FirstAppearance = !state.SeenValues.Contains(input.Value);
                state.SeenValues.Add(input.Value);
            }

            /// <summary>Starts each state with an empty set of seen values.</summary>
            public static void MyStateInit(MyState state)
            {
                state.SeenValues = new HashSet<float>();
            }
        }

        /// <summary>
        /// Runs the stateful custom mapping through <c>TestEstimatorCore</c> and counts first appearances.
        /// </summary>
        [Fact]
        public void TestStatefulCustomMappingTransformer()
        {
            string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
            var source = new MultiFileSource(dataPath);
            var loader = ML.Data.CreateTextLoader(new[] {
                new TextLoader.Column("Features", DataKind.Single, 1, 9),
                new TextLoader.Column("Label", DataKind.String, 0),
                new TextLoader.Column("Value", DataKind.Single, 2),
            });
            var data = loader.Load(source);

            // The custom transformer is instantiated on a separate, temporary environment. This ensures that
            // saving and loading the model do not require the same environment.
            var tempoEnv = new MLContext();
            var customEst = tempoEnv.Transforms.StatefulCustomMapping<MyStatefulInput, MyStatefulOutput, MyState>(MyStatefulLambda.MyStatefulAction, MyStatefulLambda.MyStateInit, nameof(MyStatefulLambda));

            TestEstimatorCore(customEst, data);
            var transformedData = customEst.Fit(data).Transform(data);
            var outputs = transformedData.GetColumn<bool>(transformedData.Schema[nameof(MyStatefulOutput.FirstAppearance)]);

            // FirstAppearance is true once per distinct Value read through this cursor, so the count of true
            // values is the number of distinct values in column 2 (10 for this dataset).
            Assert.Equal(10, outputs.Count(output => output));

            Done();
        }

        /// <summary>
        /// Checks that <c>FilterByCustomPredicate</c> drops the rows for which the predicate returns true.
        /// </summary>
        [Fact]
        public void TestCustomFilter()
        {
            string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
            var source = new MultiFileSource(dataPath);
            var loader = ML.Data.CreateTextLoader(new[] {
                new TextLoader.Column("Float4", DataKind.Single, 1, 4),
                new TextLoader.Column("Float1", DataKind.Single, 2),
            });
            var data = loader.Load(source);

            // Rows with an even Float1 are dropped, so every remaining Float1 must be non-even.
            var filteredData = ML.Data.FilterByCustomPredicate<MyInput>(data, input => input.Float1 % 2 == 0);

            // Materialize once; both asserts below read the same rows.
            var remaining = filteredData.GetColumn<float>(filteredData.Schema[nameof(MyInput.Float1)]).ToList();

            // Without this, All() over an empty result would pass vacuously.
            Assert.NotEmpty(remaining);
            Assert.True(remaining.All(x => x % 2 != 0));
        }

        /// <summary>State for the stateful filter: the number of rows the predicate has been invoked on.</summary>
        private sealed class MyFilterState
        {
            public int Count { get; set; }
        }

        /// <summary>Input row for the stateful filter. <c>Counter</c> identifies the row's original position.</summary>
        private sealed class MyFilterInput
        {
            public int Counter { get; set; }
            public int Value { get; set; }
        }

        /// <summary>
        /// Checks that <c>FilterByStatefulCustomPredicate</c> threads state through the rows and keeps the
        /// rows for which the predicate returns false.
        /// </summary>
        [Fact]
        public void TestStatefulCustomFilter()
        {
            var data = ML.Data.LoadFromEnumerable(new[]
            {
                new MyFilterInput() { Counter = 0, Value = 1 },
                new MyFilterInput() { Counter = 1, Value = 1 },
                new MyFilterInput() { Counter = 2, Value = 2 },
                new MyFilterInput() { Counter = 3, Value = 0 },
                new MyFilterInput() { Counter = 4, Value = 2 },
                new MyFilterInput() { Counter = 5, Value = 4 },
                new MyFilterInput() { Counter = 6, Value = 1 },
                new MyFilterInput() { Counter = 7, Value = 1 },
                new MyFilterInput() { Counter = 8, Value = 2 },
            });

            // The predicate alternates on every row it sees, including dropped ones.
            // On even invocations it drops even values; on odd invocations it drops odd values.
            var filteredData = ML.Data.FilterByStatefulCustomPredicate<MyFilterInput, MyFilterState>(data,
                (input, state) =>
                {
                    if (state.Count++ % 2 == 0)
                        return input.Value % 2 == 0;
                    else
                        return input.Value % 2 == 1;
                }, state => state.Count = 0);

            // Tracing the rows above, the survivors are Counter 0, 3, 5 and 6.
            var values = filteredData.GetColumn<int>(filteredData.Schema[nameof(MyFilterInput.Value)]);
            var counter = filteredData.GetColumn<int>(filteredData.Schema[nameof(MyFilterInput.Counter)]);
            Assert.Equal(new[] { 0, 3, 5, 6 }, counter);
            Assert.Equal(new[] { 1, 0, 4, 1 }, values);
        }
    }
}