// File: Dynamic\Transforms\StatefulCustomMapping.cs
// Web Access
// Project: src\docs\samples\Microsoft.ML.Samples\Microsoft.ML.Samples.csproj (Microsoft.ML.Samples)
using System;
using System.Collections.Generic;
using Microsoft.ML;
 
namespace Samples.Dynamic
{
    /// <summary>
    /// Sample for the StatefulCustomMapping estimator: a custom row mapping
    /// whose state object carries values from one row to the next.
    /// </summary>
    public static class StatefulCustomMapping
    {
        /// <summary>
        /// Loads nine rows of signs, applies a stateful mapping that emits
        /// successive Fibonacci numbers (negated when Sign is false), and
        /// prints every transformed row to the console.
        /// </summary>
        public static void Example()
        {
            // The ML context is the entry point for ML.NET operations; it also
            // provides exception tracking, logging, and a source of randomness.
            var mlContext = new MLContext();

            // Build the in-memory rows from a list of signs, then hand them to
            // ML.NET as an IDataView.
            var signs = new[] { true, true, false, true, false, true, false, false, false };
            var samples = new List<InputData>(signs.Length);
            foreach (var sign in signs)
            {
                samples.Add(new InputData { Sign = sign });
            }

            var data = mlContext.Data.LoadFromEnumerable(samples);

            // The per-row mapping and the state initializer are plain static
            // methods, converted to the delegate types the estimator expects.
            Action<InputData, CustomMappingOutput, State> mapping = EmitSignedFibonacci;
            Action<State> init = ResetState;

            // A custom mapping can transform data on its own or be chained into a
            // larger estimator pipeline. Because contractName is null, a pipeline
            // that contains this estimator cannot be saved and loaded back.
            var estimator = mlContext.Transforms.StatefulCustomMapping(
                mapping, init, contractName: null);

            // Fitting and transforming don't evaluate the data yet; the mapping
            // actually runs when the transformed rows are read below.
            var model = estimator.Fit(data);
            var transformed = model.Transform(data);

            var rows = mlContext.Data.CreateEnumerable<TransformedData>(
                transformed, reuseRowObject: true);

            Console.WriteLine("Sign\t SignedFibonacci");
            foreach (var row in rows)
            {
                Console.WriteLine($"{row.Sign}\t {row.SignedFibonacci}");
            }

            // Expected output:
            // Sign      SignedFibonacci
            // True      1
            // True      2
            // False     -3
            // True      5
            // False     -8
            // True      13
            // False     -21
            // False     -34
            // False     -55
        }

        // Mapping for one row: computes the next Fibonacci term from the state,
        // advances the state, and writes the term (negated when Sign is false).
        // Values are long, so the sequence would overflow after roughly 90 rows;
        // the nine sample rows stay far below that.
        private static void EmitSignedFibonacci(InputData input, CustomMappingOutput output, State state)
        {
            long next = state.Prev + state.Prev2;

            // Slide the two-term window forward by one position.
            state.Prev2 = state.Prev;
            state.Prev = next;

            if (input.Sign)
            {
                output.SignedFibonacci = next;
            }
            else
            {
                output.SignedFibonacci = -next;
            }
        }

        // Seeds the state so that the first emitted term is 0 + 1 = 1.
        private static void ResetState(State state)
        {
            state.Prev = 1;
            state.Prev2 = 0;
        }

        // Holds only the column the custom mapping adds; the input columns
        // are carried through unchanged.
        private class CustomMappingOutput
        {
            public long SignedFibonacci { get; set; }
        }

        // Schema of each input row.
        private class InputData
        {
            public bool Sign { get; set; }
        }

        // Running state shared between rows: the two most recent terms.
        private class State
        {
            public long Prev { get; set; }
            public long Prev2 { get; set; }
        }

        // Schema of each output row: the original Sign column plus the
        // SignedFibonacci column produced by the mapping.
        private class TransformedData : InputData
        {
            public long SignedFibonacci { get; set; }
        }
    }
}