File: MetricStatistics.cs
Project: src\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj (Microsoft.ML.Transforms)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System.Collections.Generic;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// The <see cref="MetricStatistics"/> class computes summary statistics over multiple observations of a metric.
    /// </summary>
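    /// <remarks>
    /// Instances are created and populated by ML.NET itself (the constructor and the Add method are
    /// internal); callers typically only read the exposed statistics. A minimal consumption sketch,
    /// assuming <c>stats</c> is a populated instance obtained from an aggregation API:
    /// <code>
    /// Console.WriteLine($"mean = {stats.Mean}, stddev = {stats.StandardDeviation}, n = {stats.Count}");
    /// </code>
    /// </remarks>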
    public sealed class MetricStatistics
    {
        private readonly SummaryStatistics _statistic;
 
        /// <summary>
        /// Get the mean value for the metric.
        /// </summary>
        public double Mean => _statistic.Mean;
 
        /// <summary>
        /// Get the standard deviation for the metric.
        /// </summary>
        public double StandardDeviation => (_statistic.RawCount <= 1) ? 0 : _statistic.SampleStdDev;
 
        /// <summary>
        /// Get the standard error of the mean for the metric.
        /// </summary>
        public double StandardError => (_statistic.RawCount <= 1) ? 0 : _statistic.StandardErrorMean;
 
        /// <summary>
        /// Get the number of samples used. This is useful for interpreting the standard deviation
        /// and the standard error, and for building confidence intervals.
        /// </summary>
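        /// <remarks>
        /// A normal-approximation 95% confidence interval sketch (illustrative only, not an API provided
        /// by this class), assuming <c>stats</c> is a populated instance:
        /// <code>
        /// double halfWidth = 1.96 * stats.StandardError;
        /// var interval = (low: stats.Mean - halfWidth, high: stats.Mean + halfWidth);
        /// </code>
        /// </remarks>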
        public int Count => (int)_statistic.RawCount;
 
        internal MetricStatistics()
        {
            _statistic = new SummaryStatistics();
        }
 
        /// <summary>
        /// Add another metric to the set of observations.
        /// </summary>
        /// <param name="metric">The metric value to accumulate.</param>
        internal void Add(double metric)
        {
            _statistic.Add(metric);
        }
    }
 
    internal static class MetricsStatisticsUtils
    {
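        /// <summary>
        /// Adds each value in <paramref name="src"/> to the corresponding accumulator in <paramref name="dest"/>.
        /// Both lists must have the same length.
        /// </summary>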
        public static void AddToEach(IReadOnlyList<double> src, IReadOnlyList<MetricStatistics> dest)
        {
            Contracts.Assert(src.Count == dest.Count);
 
            for (int i = 0; i < dest.Count; i++)
                dest[i].Add(src[i]);
        }
 
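        /// <summary>
        /// Creates an array of <paramref name="length"/> newly constructed <see cref="MetricStatistics"/> instances.
        /// </summary>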
        public static MetricStatistics[] InitializeArray(int length)
        {
            var array = new MetricStatistics[length];
            for (int i = 0; i < array.Length; i++)
                array[i] = new MetricStatistics();
 
            return array;
        }
    }
 
    /// <summary>
    /// This interface handles the accumulation of summary statistics over multiple observations of
    /// evaluation metrics.
    /// </summary>
    /// <typeparam name="T">The metric results type, such as <see cref="RegressionMetrics"/>.</typeparam>
    internal interface IMetricsStatistics<T>
    {
        void Add(T metrics);
    }
 
    /// <summary>
    /// The <see cref="RegressionMetricsStatistics"/> class holds summary
    /// statistics over multiple observations of <see cref="RegressionMetrics"/>.
    /// </summary>
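    /// <remarks>
    /// A minimal consumption sketch, assuming <c>stats</c> is a populated instance, for example one
    /// element of the array returned by the regression <c>PermutationFeatureImportance</c> extension
    /// (the variable names here are illustrative):
    /// <code>
    /// Console.WriteLine($"RMSE: {stats.RootMeanSquaredError.Mean} +/- {stats.RootMeanSquaredError.StandardError}");
    /// Console.WriteLine($"R^2:  {stats.RSquared.Mean} (n = {stats.RSquared.Count})");
    /// </code>
    /// </remarks>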
    public sealed class RegressionMetricsStatistics : IMetricsStatistics<RegressionMetrics>
    {
        /// <summary>
        /// Summary statistics for <see cref="RegressionMetrics.MeanAbsoluteError"/>.
        /// </summary>
        public MetricStatistics MeanAbsoluteError { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="RegressionMetrics.MeanSquaredError"/>.
        /// </summary>
        public MetricStatistics MeanSquaredError { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="RegressionMetrics.RootMeanSquaredError"/>.
        /// </summary>
        public MetricStatistics RootMeanSquaredError { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="RegressionMetrics.LossFunction"/>.
        /// </summary>
        public MetricStatistics LossFunction { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="RegressionMetrics.RSquared"/>.
        /// </summary>
        public MetricStatistics RSquared { get; }
 
        internal RegressionMetricsStatistics()
        {
            MeanAbsoluteError = new MetricStatistics();
            MeanSquaredError = new MetricStatistics();
            RootMeanSquaredError = new MetricStatistics();
            LossFunction = new MetricStatistics();
            RSquared = new MetricStatistics();
        }
 
        /// <summary>
        /// Add a set of evaluation metrics to the set of observations.
        /// </summary>
        /// <param name="metrics">The observed regression evaluation metrics.</param>
        void IMetricsStatistics<RegressionMetrics>.Add(RegressionMetrics metrics)
        {
            MeanAbsoluteError.Add(metrics.MeanAbsoluteError);
            MeanSquaredError.Add(metrics.MeanSquaredError);
            RootMeanSquaredError.Add(metrics.RootMeanSquaredError);
            LossFunction.Add(metrics.LossFunction);
            RSquared.Add(metrics.RSquared);
        }
    }
 
    /// <summary>
    /// The <see cref="BinaryClassificationMetricsStatistics"/> class holds summary
    /// statistics over multiple observations of <see cref="BinaryClassificationMetrics"/>.
    /// </summary>
    public sealed class BinaryClassificationMetricsStatistics : IMetricsStatistics<BinaryClassificationMetrics>
    {
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.AreaUnderRocCurve"/>.
        /// </summary>
        public MetricStatistics AreaUnderRocCurve { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.Accuracy"/>.
        /// </summary>
        public MetricStatistics Accuracy { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.PositivePrecision"/>.
        /// </summary>
        public MetricStatistics PositivePrecision { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.PositiveRecall"/>.
        /// </summary>
        public MetricStatistics PositiveRecall { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.NegativePrecision"/>.
        /// </summary>
        public MetricStatistics NegativePrecision { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.NegativeRecall"/>.
        /// </summary>
        public MetricStatistics NegativeRecall { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.F1Score"/>.
        /// </summary>
        public MetricStatistics F1Score { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="BinaryClassificationMetrics.AreaUnderPrecisionRecallCurve"/>.
        /// </summary>
        public MetricStatistics AreaUnderPrecisionRecallCurve { get; }
 
        internal BinaryClassificationMetricsStatistics()
        {
            AreaUnderRocCurve = new MetricStatistics();
            Accuracy = new MetricStatistics();
            PositivePrecision = new MetricStatistics();
            PositiveRecall = new MetricStatistics();
            NegativePrecision = new MetricStatistics();
            NegativeRecall = new MetricStatistics();
            F1Score = new MetricStatistics();
            AreaUnderPrecisionRecallCurve = new MetricStatistics();
        }
 
        /// <summary>
        /// Add a set of evaluation metrics to the set of observations.
        /// </summary>
        /// <param name="metrics">The observed binary classification evaluation metrics.</param>
        void IMetricsStatistics<BinaryClassificationMetrics>.Add(BinaryClassificationMetrics metrics)
        {
            AreaUnderRocCurve.Add(metrics.AreaUnderRocCurve);
            Accuracy.Add(metrics.Accuracy);
            PositivePrecision.Add(metrics.PositivePrecision);
            PositiveRecall.Add(metrics.PositiveRecall);
            NegativePrecision.Add(metrics.NegativePrecision);
            NegativeRecall.Add(metrics.NegativeRecall);
            F1Score.Add(metrics.F1Score);
            AreaUnderPrecisionRecallCurve.Add(metrics.AreaUnderPrecisionRecallCurve);
        }
    }
 
    /// <summary>
    /// The <see cref="MulticlassClassificationMetricsStatistics"/> class holds summary
    /// statistics over multiple observations of <see cref="MulticlassClassificationMetrics"/>.
    /// </summary>
    public sealed class MulticlassClassificationMetricsStatistics : IMetricsStatistics<MulticlassClassificationMetrics>
    {
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.MacroAccuracy"/>.
        /// </summary>
        public MetricStatistics MacroAccuracy { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.MicroAccuracy"/>.
        /// </summary>
        public MetricStatistics MicroAccuracy { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.LogLoss"/>.
        /// </summary>
        public MetricStatistics LogLoss { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.LogLossReduction"/>.
        /// </summary>
        public MetricStatistics LogLossReduction { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.TopKAccuracy"/>.
        /// </summary>
        public MetricStatistics TopKAccuracy { get; }
 
        /// <summary>
        /// Summary statistics for <see cref="MulticlassClassificationMetrics.PerClassLogLoss"/>.
        /// </summary>
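        /// <remarks>
        /// The list is sized lazily from the first set of metrics added; entry <c>i</c> accumulates the
        /// log-loss of class <c>i</c>.
        /// </remarks>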
        public IReadOnlyList<MetricStatistics> PerClassLogLoss { get; private set; }
 
        internal MulticlassClassificationMetricsStatistics()
        {
            MacroAccuracy = new MetricStatistics();
            MicroAccuracy = new MetricStatistics();
            LogLoss = new MetricStatistics();
            LogLossReduction = new MetricStatistics();
            TopKAccuracy = new MetricStatistics();
        }
 
        /// <summary>
        /// Add a set of evaluation metrics to the set of observations.
        /// </summary>
        /// <param name="metrics">The observed multiclass classification evaluation metrics.</param>
        void IMetricsStatistics<MulticlassClassificationMetrics>.Add(MulticlassClassificationMetrics metrics)
        {
            MacroAccuracy.Add(metrics.MacroAccuracy);
            MicroAccuracy.Add(metrics.MicroAccuracy);
            LogLoss.Add(metrics.LogLoss);
            LogLossReduction.Add(metrics.LogLossReduction);
            TopKAccuracy.Add(metrics.TopKAccuracy);
 
            if (PerClassLogLoss == null)
                PerClassLogLoss = MetricsStatisticsUtils.InitializeArray(metrics.PerClassLogLoss.Count);
            MetricsStatisticsUtils.AddToEach(metrics.PerClassLogLoss, PerClassLogLoss);
        }
    }
 
    /// <summary>
    /// The <see cref="RankingMetricsStatistics"/> class holds summary
    /// statistics over multiple observations of <see cref="RankingMetrics"/>.
    /// </summary>
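    /// <remarks>
    /// The per-truncation-level lists are sized lazily from the first <see cref="RankingMetrics"/>
    /// observation added.
    /// </remarks>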
    public sealed class RankingMetricsStatistics : IMetricsStatistics<RankingMetrics>
    {
        /// <summary>
        /// Summary statistics for <see cref="RankingMetrics.DiscountedCumulativeGains"/>.
        /// </summary>
        public IReadOnlyList<MetricStatistics> DiscountedCumulativeGains { get; private set; }
 
        /// <summary>
        /// Summary statistics for <see cref="RankingMetrics.NormalizedDiscountedCumulativeGains"/>.
        /// </summary>
        public IReadOnlyList<MetricStatistics> NormalizedDiscountedCumulativeGains { get; private set; }
 
        internal RankingMetricsStatistics()
        {
        }
 
        /// <summary>
        /// Add a set of evaluation metrics to the set of observations.
        /// </summary>
        /// <param name="metrics">The observed ranking evaluation metrics.</param>
        void IMetricsStatistics<RankingMetrics>.Add(RankingMetrics metrics)
        {
            if (DiscountedCumulativeGains == null)
                DiscountedCumulativeGains = MetricsStatisticsUtils.InitializeArray(metrics.DiscountedCumulativeGains.Count);
 
            if (NormalizedDiscountedCumulativeGains == null)
                NormalizedDiscountedCumulativeGains = MetricsStatisticsUtils.InitializeArray(metrics.NormalizedDiscountedCumulativeGains.Count);
 
            MetricsStatisticsUtils.AddToEach(metrics.DiscountedCumulativeGains, DiscountedCumulativeGains);
            MetricsStatisticsUtils.AddToEach(metrics.NormalizedDiscountedCumulativeGains, NormalizedDiscountedCumulativeGains);
        }
    }
}