File: AutoMLExperiment\IMetricManager.cs
Web Access
Project: src\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj (Microsoft.ML.AutoML)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
 
namespace Microsoft.ML.AutoML
{
    /// <summary>
    /// Interface for metric manager.
    /// </summary>
    public interface IMetricManager
    {
        bool IsMaximize { get; }
 
        string MetricName { get; }
    }
 
    public interface IEvaluateMetricManager : IMetricManager
    {
        double Evaluate(MLContext context, IDataView eval);
    }
 
    internal class BinaryMetricManager : IEvaluateMetricManager
    {
        public BinaryMetricManager(BinaryClassificationMetric metric, string labelColumn, string predictedColumn)
        {
            Metric = metric;
            PredictedColumn = predictedColumn;
            LabelColumn = labelColumn;
        }
 
        public BinaryClassificationMetric Metric { get; set; }
 
        public string PredictedColumn { get; set; }
 
        public string LabelColumn { get; set; }
 
        public bool IsMaximize => Metric switch
        {
            BinaryClassificationMetric.Accuracy => true,
            BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => true,
            BinaryClassificationMetric.AreaUnderRocCurve => true,
            BinaryClassificationMetric.PositivePrecision => true,
            BinaryClassificationMetric.NegativePrecision => true,
            BinaryClassificationMetric.NegativeRecall => true,
            BinaryClassificationMetric.PositiveRecall => true,
            BinaryClassificationMetric.F1Score => true,
            _ => throw new NotImplementedException(),
        };
 
        public string MetricName => Metric.ToString();
 
        public double Evaluate(MLContext context, IDataView eval)
        {
            var metric = context.BinaryClassification.EvaluateNonCalibrated(eval, labelColumnName: LabelColumn, predictedLabelColumnName: PredictedColumn);
 
            return Metric switch
            {
                BinaryClassificationMetric.Accuracy => metric.Accuracy,
                BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metric.AreaUnderPrecisionRecallCurve,
                BinaryClassificationMetric.AreaUnderRocCurve => metric.AreaUnderRocCurve,
                BinaryClassificationMetric.PositivePrecision => metric.PositivePrecision,
                BinaryClassificationMetric.NegativePrecision => metric.NegativePrecision,
                BinaryClassificationMetric.NegativeRecall => metric.NegativeRecall,
                BinaryClassificationMetric.PositiveRecall => metric.PositivePrecision,
                BinaryClassificationMetric.F1Score => metric.F1Score,
                _ => throw new NotImplementedException(),
            };
        }
    }
 
    internal class MultiClassMetricManager : IEvaluateMetricManager
    {
        public MulticlassClassificationMetric Metric { get; set; }
 
        public string PredictedColumn { get; set; }
 
        public string LabelColumn { get; set; }
 
        public bool IsMaximize => Metric switch
        {
            MulticlassClassificationMetric.MacroAccuracy => true,
            MulticlassClassificationMetric.MicroAccuracy => true,
            MulticlassClassificationMetric.LogLoss => false,
            MulticlassClassificationMetric.LogLossReduction => false,
            MulticlassClassificationMetric.TopKAccuracy => true,
            _ => throw new NotImplementedException(),
        };
 
        public string MetricName => Metric.ToString();
 
        public double Evaluate(MLContext context, IDataView eval)
        {
            var metric = context.MulticlassClassification.Evaluate(eval, labelColumnName: LabelColumn, predictedLabelColumnName: PredictedColumn);
 
            return Metric switch
            {
                MulticlassClassificationMetric.MacroAccuracy => metric.MacroAccuracy,
                MulticlassClassificationMetric.MicroAccuracy => metric.MicroAccuracy,
                MulticlassClassificationMetric.LogLoss => metric.LogLoss,
                MulticlassClassificationMetric.LogLossReduction => metric.LogLossReduction,
                MulticlassClassificationMetric.TopKAccuracy => metric.TopKAccuracy,
                _ => throw new NotImplementedException(),
            };
        }
    }
 
    internal class RegressionMetricManager : IEvaluateMetricManager
    {
        public RegressionMetric Metric { get; set; }
 
        public string ScoreColumn { get; set; }
 
        public string LabelColumn { get; set; }
 
        public bool IsMaximize => Metric switch
        {
            RegressionMetric.RSquared => true,
            RegressionMetric.RootMeanSquaredError => false,
            RegressionMetric.MeanSquaredError => false,
            RegressionMetric.MeanAbsoluteError => false,
            _ => throw new NotImplementedException(),
        };
 
        public string MetricName => Metric.ToString();
 
        public double Evaluate(MLContext context, IDataView eval)
        {
            var metric = context.Regression.Evaluate(eval, LabelColumn, ScoreColumn);
 
            return Metric switch
            {
                RegressionMetric.RSquared => metric.RSquared,
                RegressionMetric.RootMeanSquaredError => metric.RootMeanSquaredError,
                RegressionMetric.MeanSquaredError => metric.MeanSquaredError,
                RegressionMetric.MeanAbsoluteError => metric.MeanAbsoluteError,
                _ => throw new NotImplementedException(),
            };
        }
    }
}