// File: Evaluators\Metrics\ConfusionMatrix.cs
// Web Access
// Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.ML.Runtime;
 
namespace Microsoft.ML.Data
{
    /// <summary>
    /// Represents the <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a> of the classification results.
    /// </summary>
    public sealed class ConfusionMatrix
    {
        /// <summary>
        /// The calculated value of <a href="https://en.wikipedia.org/wiki/Precision_and_recall#Precision">precision</a> for each class.
        /// </summary>
        public IReadOnlyList<double> PerClassPrecision { get; }

        /// <summary>
        /// The calculated value of <a href="https://en.wikipedia.org/wiki/Precision_and_recall#Recall">recall</a> for each class.
        /// </summary>
        public IReadOnlyList<double> PerClassRecall { get; }

        /// <summary>
        /// The confusion matrix counts for the combinations actual class/predicted class.
        /// The actual classes are in the rows of the table (stored in the outer <see cref="IReadOnlyList{T}"/>), and the predicted classes
        /// in the columns (stored in the inner <see cref="IReadOnlyList{T}"/>).
        /// </summary>
        public IReadOnlyList<IReadOnlyList<double>> Counts { get; }

        /// <summary>
        /// The number of classes in the confusion matrix. It is equal to the number of rows of <see cref="Counts"/>.
        /// </summary>
        public int NumberOfClasses { get; }

        /// <summary>
        /// The indicators of the predicted classes.
        /// It might be the classes names, or just indices of the predicted classes, if the name mapping is missing.
        /// </summary>
        internal IReadOnlyList<ReadOnlyMemory<char>> PredictedClassesIndicators;

        /// <summary>
        /// Whether the classes are sampled, as passed to the constructor.
        /// </summary>
        internal readonly bool IsSampled;

        /// <summary>
        /// Whether the confusion table is the result of a binary classification, as passed to the constructor.
        /// </summary>
        internal readonly bool IsBinary;

        private readonly IHost _host;

        // Cache for GetFormattedConfusionTable; stays null until the first call.
        private string _formattedConfusionMatrix;

        /// <summary>
        /// Creates the confusion matrix as a structured type, built from the counts of the confusion table <see cref="IDataView"/> that
        /// the <see cref="BinaryClassifierEvaluator"/> or the <see cref="MulticlassClassificationEvaluator"/> constructs.
        /// </summary>
        /// <param name="host">The IHost instance.</param>
        /// <param name="precision">The values of precision per class.</param>
        /// <param name="recall">The values of recall per class.</param>
        /// <param name="confusionTableCounts">The counts of the confusion table. The actual classes values are in the rows of the 2D array,
        /// and the counts of the predicted classes are in the columns.</param>
        /// <param name="labelNames">The predicted classes names, or the indexes of the classes, if the names are missing.</param>
        /// <param name="isSampled">Whether the classes are sampled.</param>
        /// <param name="isBinary">Whether the confusion table is the result of a binary classification.</param>
        internal ConfusionMatrix(IHost host, double[] precision, double[] recall, double[][] confusionTableCounts,
             List<ReadOnlyMemory<char>> labelNames, bool isSampled, bool isBinary)
        {
            Contracts.AssertValue(host);
            _host = host;

            _host.AssertNonEmpty(precision);
            _host.AssertNonEmpty(recall);
            _host.AssertNonEmpty(confusionTableCounts);
            _host.AssertNonEmpty(labelNames);

            // Every per-class input must have one entry per row of the confusion table.
            _host.Assert(precision.Length == confusionTableCounts.Length);
            _host.Assert(recall.Length == confusionTableCounts.Length);
            _host.Assert(labelNames.Count == confusionTableCounts.Length);

            PerClassPrecision = precision.ToImmutableArray();
            PerClassRecall = recall.ToImmutableArray();
            IsSampled = isSampled;
            IsBinary = isBinary;

            // Note: AsReadOnly wraps the list that was passed in; it does not copy it.
            PredictedClassesIndicators = labelNames.AsReadOnly();

            NumberOfClasses = confusionTableCounts.Length;
            List<IReadOnlyList<double>> counts = new List<IReadOnlyList<double>>(NumberOfClasses);

            // Copy each row (ToList) before wrapping it, so the exposed counts are detached from the input arrays.
            foreach (double[] row in confusionTableCounts)
                counts.Add(row.ToList().AsReadOnly());

            Counts = counts.AsReadOnly();
        }

        /// <summary>
        /// Returns a human readable representation of the confusion table.
        /// The string is built on the first call and the cached value is returned on subsequent calls.
        /// </summary>
        /// <returns>The formatted confusion table.</returns>
        public string GetFormattedConfusionTable()
        {
            if (_formattedConfusionMatrix == null)
                _formattedConfusionMatrix = MetricWriter.GetConfusionTableAsString(this, false);

            return _formattedConfusionMatrix;
        }

        /// <summary>
        /// Gets the confusion table count for the pair <paramref name="predictedClassIndicatorIndex"/>/<paramref name="actualClassIndicatorIndex"/>.
        /// </summary>
        /// <remarks>
        /// Both indices are validated against the number of rows of <see cref="Counts"/> before the lookup.
        /// </remarks>
        /// <param name="predictedClassIndicatorIndex">The index of the predicted label indicator, in the <see cref="PredictedClassesIndicators"/>.
        /// It selects the column of <see cref="Counts"/>.</param>
        /// <param name="actualClassIndicatorIndex">The index of the actual label indicator, in the <see cref="PredictedClassesIndicators"/>.
        /// It selects the row of <see cref="Counts"/>.</param>
        /// <returns>The value stored in <see cref="Counts"/> at row <paramref name="actualClassIndicatorIndex"/>,
        /// column <paramref name="predictedClassIndicatorIndex"/>.</returns>
        public double GetCountForClassPair(int predictedClassIndicatorIndex, int actualClassIndicatorIndex)
        {
            // Both bounds use the row count: this relies on the table being square,
            // which the constructor does not verify for the individual rows.
            _host.CheckParam(predictedClassIndicatorIndex > -1 && predictedClassIndicatorIndex < Counts.Count,
                nameof(predictedClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes.");
            _host.CheckParam(actualClassIndicatorIndex > -1 && actualClassIndicatorIndex < Counts.Count,
                nameof(actualClassIndicatorIndex), "Invalid index. Should be non-negative, less than the number of classes.");

            // Rows are actual classes, columns are predicted classes.
            return Counts[actualClassIndicatorIndex][predictedClassIndicatorIndex];
        }
    }
}