File: ScenariosWithDirectInstantiation\IrisPlantClassificationTests.cs
Web Access
Project: src\test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj (Microsoft.ML.Tests)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFrameworkCommon;
using Microsoft.ML.Trainers;
using Xunit;
 
namespace Microsoft.ML.Scenarios
{
    public partial class ScenariosTests
    {
        [Fact]
        public void TrainAndPredictIrisModelUsingDirectInstantiationTest()
        {
            var mlContext = new MLContext(seed: 1);
 
            var reader = mlContext.Data.CreateTextLoader(columns: new[]
                {
                    new TextLoader.Column("Label", DataKind.Single, 0),
                    new TextLoader.Column("SepalLength", DataKind.Single, 1),
                    new TextLoader.Column("SepalWidth", DataKind.Single, 2),
                    new TextLoader.Column("PetalLength", DataKind.Single, 3),
                    new TextLoader.Column("PetalWidth", DataKind.Single, 4)
                }
            );
 
            var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
                .Append(mlContext.Transforms.NormalizeMinMax("Features"))
                .Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
                .AppendCacheCheckpoint(mlContext)
                .Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(
                    new SdcaMaximumEntropyMulticlassTrainer.Options { NumberOfThreads = 1 }));
 
            // Read training and test data sets
            string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
            string testDataPath = dataPath;
            var trainData = reader.Load(dataPath);
            var testData = reader.Load(testDataPath);
 
            // Train the pipeline
            var trainedModel = pipe.Fit(trainData);
 
            // Make prediction and then evaluate the trained pipeline
            var predicted = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(predicted);
            CompareMetrics(metrics);
            var predictFunction = mlContext.Model.CreatePredictionEngine<IrisData, IrisPrediction>(trainedModel);
            ComparePredictions(predictFunction);
        }
 
        private void ComparePredictions(PredictionEngine<IrisData, IrisPrediction> model)
        {
            IrisPrediction prediction = model.Predict(new IrisData()
            {
                SepalLength = 5.1f,
                SepalWidth = 3.3f,
                PetalLength = 1.6f,
                PetalWidth = 0.2f,
            });
 
            Assert.Equal(1d, prediction.PredictedLabels[0], 0.01);
            Assert.Equal(0d, prediction.PredictedLabels[1], 0.01);
            Assert.Equal(0d, prediction.PredictedLabels[2], 0.01);
 
            prediction = model.Predict(new IrisData()
            {
                SepalLength = 6.4f,
                SepalWidth = 3.1f,
                PetalLength = 5.5f,
                PetalWidth = 2.2f,
            });
 
            Assert.Equal(0d, prediction.PredictedLabels[0], 0.01);
            Assert.Equal(0d, prediction.PredictedLabels[1], 0.01);
            Assert.Equal(1d, prediction.PredictedLabels[2], 0.01);
 
            prediction = model.Predict(new IrisData()
            {
                SepalLength = 4.4f,
                SepalWidth = 3.1f,
                PetalLength = 2.5f,
                PetalWidth = 1.2f,
            });
 
            Assert.Equal(.2, prediction.PredictedLabels[0], 0.1);
            Assert.Equal(.8, prediction.PredictedLabels[1], 0.1);
            Assert.Equal(0d, prediction.PredictedLabels[2], 0.01);
        }
 
        private void CompareMetrics(MulticlassClassificationMetrics metrics)
        {
            Assert.Equal(.98, metrics.MacroAccuracy);
            Assert.Equal(.98, metrics.MicroAccuracy, 0.01);
            Assert.InRange(metrics.LogLoss, .05, .06);
            Assert.InRange(metrics.LogLossReduction, 0.94, 0.96);
 
            Assert.Equal(3, metrics.PerClassLogLoss.Count);
            Assert.Equal(0d, metrics.PerClassLogLoss[0], 0.1);
            Assert.Equal(.1, metrics.PerClassLogLoss[1], 0.1);
            Assert.Equal(.1, metrics.PerClassLogLoss[2], 0.1);
        }
    }
}