|
using System;
using System.Collections.Generic;
using Microsoft.ML;
namespace Samples.Dynamic
{
public static class StatefulCustomMapping
{
    /// <summary>
    /// Demonstrates a stateful custom mapping. The mapping gives each row the
    /// next Fibonacci term, negated when that row's <c>Sign</c> is
    /// <c>false</c>. The running pair of previous terms is kept in a
    /// <see cref="State"/> object, which the mapping reads and updates on
    /// every row. The result is written to the console.
    /// </summary>
    public static void Example()
    {
        // The ML context is the entry point for ML.NET operations. It is also
        // used for exception tracking and logging, and as the source of
        // randomness.
        var context = new MLContext();

        // Build the sample rows from a list of signs, then expose them to
        // ML.NET as an IDataView.
        var samples = new List<bool>
        {
            true, true, false, true, false, true, false, false, false,
        }.ConvertAll(sign => new InputData { Sign = sign });
        var data = context.Data.LoadFromEnumerable(samples);

        // Per-row mapping. It advances the Fibonacci pair held in the state
        // and emits the new term, negated for rows whose Sign is false.
        Action<InputData, CustomMappingOutput, State> mapping =
            (source, destination, current) =>
            {
                long next = current.Prev + current.Prev2;
                current.Prev2 = current.Prev;
                current.Prev = next;
                destination.SignedFibonacci = source.Sign ? next : -next;
            };

        // State initializer. It seeds the pair so that the first mapped row
        // produces 0 + 1 = 1.
        Action<State> init =
            current =>
            {
                current.Prev = 1;
                current.Prev2 = 0;
            };

        // A custom transformation can transform data directly, or it can be
        // one estimator in a larger pipeline. Because contractName is null
        // here, a pipeline that contains this estimator cannot be saved and
        // loaded back.
        var estimator = context.Transforms.StatefulCustomMapping(
            mapping, init, contractName: null);

        // Fit and transform. The data is not evaluated until it is read
        // through the enumerable below.
        var transformedData = estimator.Fit(data).Transform(data);
        var rows = context.Data.CreateEnumerable<TransformedData>(
            transformedData, reuseRowObject: true);

        Console.WriteLine("Sign\t SignedFibonacci");
        foreach (var row in rows)
        {
            Console.WriteLine($"{row.Sign}\t {row.SignedFibonacci}");
        }

        // Expected output:
        // Sign SignedFibonacci
        // True 1
        // True 2
        // False -3
        // True 5
        // False -8
        // True 13
        // False -21
        // False -34
        // False -55
    }

    // Input row schema: one boolean column that controls the output sign.
    private class InputData
    {
        public bool Sign { get; set; }
    }

    // The single column that the custom mapping adds alongside the columns
    // already present in the input.
    private class CustomMappingOutput
    {
        public long SignedFibonacci { get; set; }
    }

    // Mutable state used by the mapping. It holds the last two Fibonacci
    // terms: Prev is the most recent and Prev2 is the one before it.
    private class State
    {
        public long Prev { get; set; }

        public long Prev2 { get; set; }
    }

    // Schema of a transformed row. It has the inherited Sign column plus the
    // generated SignedFibonacci column.
    private class TransformedData : InputData
    {
        public long SignedFibonacci { get; set; }
    }
}
}
|