File: TrainerEstimators\FAFMEstimator.cs
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Xunit;
 
namespace Microsoft.ML.Tests.TrainerEstimators
{
    public partial class TrainerEstimators : TestDataPipeBase
    {
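        // Trains an FFM binary classifier with default options: the single "Features" column
        // (copied from Field0) acts as the only field, and the metrics are checked against
        // loose lower bounds.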
        [FieldAwareFactorizationMachineFact]
        public void FfmBinaryClassificationWithoutArguments()
        {
            var mlContext = new MLContext(seed: 0);
            var data = GenerateFfmSamples(500);
            var dataView = mlContext.Data.LoadFromEnumerable(data);
 
            var pipeline = mlContext.Transforms.CopyColumns(DefaultColumnNames.Features, nameof(FfmExample.Field0))
                .Append(mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine());
 
            var model = pipeline.Fit(dataView);
            var prediction = model.Transform(dataView);
 
            var metrics = mlContext.BinaryClassification.Evaluate(prediction);
 
            // Run a sanity check against a few of the metrics.
            Assert.InRange(metrics.Accuracy, 0.6, 1);
            Assert.InRange(metrics.AreaUnderRocCurve, 0.7, 1);
            Assert.InRange(metrics.AreaUnderPrecisionRecallCurve, 0.65, 1);
        }
 
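        // Trains an FFM binary classifier with explicit options so that Field0, Field1, and Field2
        // are each treated as a separate field; with all three fields in play the metric
        // thresholds can be tightened.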
        [FieldAwareFactorizationMachineFact]
        public void FfmBinaryClassificationWithAdvancedArguments()
        {
            var mlContext = new MLContext(seed: 0);
            var data = GenerateFfmSamples(500);
            var dataView = mlContext.Data.LoadFromEnumerable(data);
 
            var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options();
 
            // Customize the field column names: the feature column is the first field, and each
            // extra feature column is treated as an additional field.
            ffmArgs.FeatureColumnName = nameof(FfmExample.Field0); // First field.
            ffmArgs.ExtraFeatureColumns = new[] { nameof(FfmExample.Field1), nameof(FfmExample.Field2) }; // Second and third fields.
 
            var pipeline = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs);
 
            var model = pipeline.Fit(dataView);
            var prediction = model.Transform(dataView);
 
            var metrics = mlContext.BinaryClassification.Evaluate(prediction);
 
            // Run a sanity check against a few of the metrics.
            Assert.InRange(metrics.Accuracy, 0.9, 1);
            Assert.InRange(metrics.AreaUnderRocCurve, 0.9, 1);
            Assert.InRange(metrics.AreaUnderPrecisionRecallCurve, 0.9, 1);
        }
 
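        // Runs the shared estimator checks (TestEstimatorCore) on the FFM trainer over the
        // breast-cancer dataset, then re-fits with validation data and the first model's
        // parameters to exercise incremental training.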
        [FieldAwareFactorizationMachineFact]
        public void FieldAwareFactorizationMachine_Estimator()
        {
            var data = new TextLoader(Env, GetFafmBCLoaderArgs())
                    .Load(GetDataPath(TestDatasets.breastCancer.trainFilename));
 
            var ffmArgs = new FieldAwareFactorizationMachineTrainer.Options
            {
                FeatureColumnName = "Feature1", // Feature column for the 1st field.
                ExtraFeatureColumns = new[] { "Feature2", "Feature3", "Feature4" }, // Feature columns for the 2nd, 3rd, and 4th fields.
                Shuffle = false,
                NumberOfIterations = 3,
                LatentDimension = 7,
            };
 
            var est = ML.BinaryClassification.Trainers.FieldAwareFactorizationMachine(ffmArgs);
 
            TestEstimatorCore(est, data);
            var model = est.Fit(data);
            var anotherModel = est.Fit(data, data, model.Model);
 
            Done();
        }
 
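        // Loader options for the breast-cancer text file: column 0 is the Boolean label and the
        // remaining columns are grouped into four single-precision vectors, one per field.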
        private TextLoader.Options GetFafmBCLoaderArgs()
        {
            return new TextLoader.Options()
            {
                Separator = "\t",
                HasHeader = false,
                Columns = new[]
                {
                    new TextLoader.Column("Feature1", DataKind.Single, new [] { new TextLoader.Range(1, 2) }),
                    new TextLoader.Column("Feature2", DataKind.Single, new [] { new TextLoader.Range(3, 4) }),
                    new TextLoader.Column("Feature3", DataKind.Single, new [] { new TextLoader.Range(5, 6) }),
                    new TextLoader.Column("Feature4", DataKind.Single, new [] { new TextLoader.Range(7, 9) }),
                    new TextLoader.Column("Label", DataKind.Boolean, 0)
                }
            };
        }
 
        private const int _simpleBinaryClassSampleFeatureLength = 10;
 
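        // A synthetic training example: a Boolean label plus three fixed-length feature vectors,
        // each fed to the FFM trainer as a separate field.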
        private class FfmExample
        {
            public bool Label;
 
            [VectorType(_simpleBinaryClassSampleFeatureLength)]
            public float[] Field0;
 
            [VectorType(_simpleBinaryClassSampleFeatureLength)]
            public float[] Field1;
 
            [VectorType(_simpleBinaryClassSampleFeatureLength)]
            public float[] Field2;
        }
 
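        // Generates synthetic examples whose feature values are shifted according to the label
        // (up for Field0 and Field2, down for Field1), so the two classes are separable.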
        private static IEnumerable<FfmExample> GenerateFfmSamples(int exampleCount)
        {
            var rnd = new Random(0);
            var data = new List<FfmExample>();
            for (int i = 0; i < exampleCount; ++i)
            {
                // Initialize an example with a random label and empty feature vectors.
                var sample = new FfmExample()
                {
                    Label = rnd.Next() % 2 == 0,
                    Field0 = new float[_simpleBinaryClassSampleFeatureLength],
                    Field1 = new float[_simpleBinaryClassSampleFeatureLength],
                    Field2 = new float[_simpleBinaryClassSampleFeatureLength]
                };
                // Fill the feature vectors according to the assigned label.
                for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
                {
                    var value0 = (float)rnd.NextDouble();
                    // Positive class gets larger feature value.
                    if (sample.Label)
                        value0 += 0.2f;
                    sample.Field0[j] = value0;
 
                    var value1 = (float)rnd.NextDouble();
                    // Positive class gets smaller feature value.
                    if (sample.Label)
                        value1 -= 0.2f;
                    sample.Field1[j] = value1;
 
                    var value2 = (float)rnd.NextDouble();
                    // Positive class gets larger feature value.
                    if (sample.Label)
                        value2 += 0.8f;
                    sample.Field2[j] = value2;
                }
 
                data.Add(sample);
            }
            return data;
        }
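
        // Illustrative sketch, not part of the original tests: scoring a single FfmExample with a
        // fitted pipeline such as the one produced in FfmBinaryClassificationWithoutArguments.
        // The FfmPrediction class and ScoreSingleExample helper below are hypothetical additions;
        // the output columns mapped here are the standard binary-classification scorer columns.
        private class FfmPrediction
        {
            public bool PredictedLabel { get; set; }
            public float Probability { get; set; }
            public float Score { get; set; }
        }

        private static void ScoreSingleExample(MLContext mlContext, ITransformer model, FfmExample example)
        {
            // PredictionEngine wraps the fitted transformer chain for one-row-at-a-time scoring.
            var engine = mlContext.Model.CreatePredictionEngine<FfmExample, FfmPrediction>(model);
            var prediction = engine.Predict(example);
            Console.WriteLine($"Predicted label: {prediction.PredictedLabel}, probability: {prediction.Probability:F3}");
        }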
    }
}