File: Scenarios\OvaTest.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Xunit;
 
namespace Microsoft.ML.Scenarios
{
    // Scenario tests for the One-Versus-All (OVA) multiclass trainer wrapped around several
    // binary trainers. Every test trains and evaluates on the same iris.txt data.
    public partial class ScenariosTests
    {
        [Fact]
        public void OvaLogisticRegression()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 1);

            // Data
            var data = LoadOvaIrisData(mlContext);

            // Pipeline
            var logReg = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();
            var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(logReg, useProbabilities: false);

            var model = pipeline.Fit(data);
            var predictions = model.Transform(data);

            // Metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
            AssertOvaMicroAccuracyAbove(metrics, 0.94);
        }

        [Fact]
        public void OvaAveragedPerceptron()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 1);

            // Data
            var data = LoadOvaIrisData(mlContext);

            // Pipeline
            var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
                    new AveragedPerceptronTrainer.Options { Shuffle = true });

            var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap, useProbabilities: false);

            var model = pipeline.Fit(data);
            var predictions = model.Transform(data);

            // Metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
            AssertOvaMicroAccuracyAbove(metrics, 0.66);
        }

        [Fact]
        public void OvaFastTree()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 1);

            // Data
            var data = LoadOvaIrisData(mlContext);

            // Pipeline. A single thread is requested explicitly through the trainer options.
            var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(
                mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }),
                useProbabilities: false);

            var model = pipeline.Fit(data);
            var predictions = model.Transform(data);

            // Metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
            AssertOvaMicroAccuracyAbove(metrics, 0.99);
        }

        [Fact]
        public void OvaLinearSvm()
        {
            // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
            // as a catalog of available operations and as the source of randomness.
            var mlContext = new MLContext(seed: 1);

            // Data
            var data = LoadOvaIrisData(mlContext);

            // Pipeline
            var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(
                mlContext.BinaryClassification.Trainers.LinearSvm(new LinearSvmTrainer.Options { NumberOfIterations = 100 }),
                useProbabilities: false);

            var model = pipeline.Fit(data);
            var predictions = model.Transform(data);

            // Metrics
            var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
            AssertOvaMicroAccuracyAbove(metrics, 0.83);
        }

        /// <summary>
        /// Loads iris.txt for the OVA tests. Column 0 is read as a Single "Label" and columns 1-4 as a
        /// Single vector "Features". "Label" is then mapped to a key type, and the resulting view is
        /// cached because each test both fits on it and transforms it.
        /// </summary>
        /// <param name="mlContext">The test's context, used to build the loader, the transform and the cache.</param>
        /// <returns>The cached data view with a key-typed "Label" column.</returns>
        private IDataView LoadOvaIrisData(MLContext mlContext)
        {
            var reader = new TextLoader(mlContext, new TextLoader.Options()
            {
                Columns = new[]
                {
                    new TextLoader.Column("Label", DataKind.Single, 0),
                    new TextLoader.Column("Features", DataKind.Single, new[] { new TextLoader.Range(1, 4) }),
                }
            });

            // Resolve the data path exactly once. The previous code fed the already-resolved
            // path back through GetDataPath a second time.
            var textData = reader.Load(GetDataPath("iris.txt"));

            var keyedData = mlContext.Transforms.Conversion.MapValueToKey("Label")
                .Fit(textData)
                .Transform(textData);

            return mlContext.Data.Cache(keyedData);
        }

        /// <summary>
        /// Asserts that <paramref name="metrics"/> has a micro-accuracy strictly greater than
        /// <paramref name="threshold"/>. The failure message reports the actual value.
        /// </summary>
        private static void AssertOvaMicroAccuracyAbove(MulticlassClassificationMetrics metrics, double threshold)
        {
            Assert.True(metrics.MicroAccuracy > threshold,
                $"Expected MicroAccuracy > {threshold}, but it was {metrics.MicroAccuracy}.");
        }
    }
}