|
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Xunit;
namespace Microsoft.ML.Scenarios
{
public partial class ScenariosTests
{
[Fact]
public void TrainAndPredictIrisModelWithStringLabelTest()
{
var mlContext = new MLContext(seed: 1);
var reader = mlContext.Data.CreateTextLoader(columns: new[]
{
new TextLoader.Column("SepalLength", DataKind.Single, 0),
new TextLoader.Column("SepalWidth", DataKind.Single, 1),
new TextLoader.Column("PetalLength", DataKind.Single, 2),
new TextLoader.Column("PetalWidth", DataKind.Single, 3),
new TextLoader.Column("IrisPlantType", DataKind.String, 4),
},
separatorChar: ','
);
// Read training and test data sets
string dataPath = GetDataPath("iris.data");
string testDataPath = dataPath;
var trainData = reader.Load(dataPath);
var testData = reader.Load(testDataPath);
// Create Estimator
var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "IrisPlantType"), TransformerScope.TrainTest)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(
new SdcaMaximumEntropyMulticlassTrainer.Options { NumberOfThreads = 1 }))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("Plant", "PredictedLabel"));
// Train the pipeline
var trainedModel = pipe.Fit(trainData);
// Make predictions
var predictFunction = mlContext.Model.CreatePredictionEngine<IrisDataWithStringLabel, IrisPredictionWithStringLabel>(trainedModel);
IrisPredictionWithStringLabel prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 5.1f,
SepalWidth = 3.3f,
PetalLength = 1.6f,
PetalWidth = 0.2f,
});
Assert.Equal(1d, prediction.PredictedScores[0], 0.01);
Assert.Equal(0d, prediction.PredictedScores[1], 0.01);
Assert.Equal(0d, prediction.PredictedScores[2], 0.01);
Assert.True(prediction.PredictedPlant == "Iris-setosa");
prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 6.4f,
SepalWidth = 3.1f,
PetalLength = 5.5f,
PetalWidth = 2.2f,
});
Assert.Equal(0d, prediction.PredictedScores[0], 0.01);
Assert.Equal(0d, prediction.PredictedScores[1], 0.01);
Assert.Equal(1d, prediction.PredictedScores[2], 0.01);
Assert.True(prediction.PredictedPlant == "Iris-virginica");
prediction = predictFunction.Predict(new IrisDataWithStringLabel()
{
SepalLength = 4.4f,
SepalWidth = 3.1f,
PetalLength = 2.5f,
PetalWidth = 1.2f,
});
Assert.Equal(.2, prediction.PredictedScores[0], 0.1);
Assert.Equal(.8, prediction.PredictedScores[1], 0.1);
Assert.Equal(0d, prediction.PredictedScores[2], 0.01);
Assert.True(prediction.PredictedPlant == "Iris-versicolor");
// Evaluate the trained pipeline
var predicted = trainedModel.Transform(testData, TransformerScope.Everything);
var metrics = mlContext.MulticlassClassification.Evaluate(predicted, topKPredictionCount: 3);
Assert.Equal(.98, metrics.MacroAccuracy);
Assert.Equal(.98, metrics.MicroAccuracy, 0.01);
Assert.Equal(.06, metrics.LogLoss, 0.01);
Assert.InRange(metrics.LogLossReduction, 0.94, 0.96);
Assert.Equal(1, metrics.TopKAccuracy);
Assert.Equal(3, metrics.PerClassLogLoss.Count);
Assert.Equal(0d, metrics.PerClassLogLoss[0], 0.1);
Assert.Equal(.1, metrics.PerClassLogLoss[1], 0.1);
Assert.Equal(.1, metrics.PerClassLogLoss[2], 0.1);
}
private class IrisDataWithStringLabel
{
[LoadColumn(0)]
public float SepalLength;
[LoadColumn(1)]
public float SepalWidth;
[LoadColumn(2)]
public float PetalLength;
[LoadColumn(3)]
public float PetalWidth;
[LoadColumn(4)]
public string IrisPlantType { get; set; }
}
private class IrisPredictionWithStringLabel
{
[ColumnName("Score")]
public float[] PredictedScores { get; set; }
[ColumnName("Plant")]
public string PredictedPlant { get; set; }
}
}
}
|