File: Evaluators\BinaryClassifierEvaluator.cs
Web Access
Project: src\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj (Microsoft.ML.Data)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
 
// Makes BinaryClassifierEvaluator loadable under SignatureEvaluator, constructed from BinaryClassifierEvaluator.Arguments.
// The trailing strings are presumably its display name followed by load-name aliases (LoadName, "BinaryClassifier",
// "Binary", "bin") — confirm against the LoadableClass attribute definition.
[assembly: LoadableClass(typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator), typeof(BinaryClassifierEvaluator.Arguments), typeof(SignatureEvaluator),
    "Binary Classifier Evaluator", BinaryClassifierEvaluator.LoadName, "BinaryClassifier", "Binary", "bin")]

// Same names, registered for the MAML evaluator wrapper (SignatureMamlEvaluator).
[assembly: LoadableClass(typeof(BinaryClassifierMamlEvaluator), typeof(BinaryClassifierMamlEvaluator), typeof(BinaryClassifierMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
    "Binary Classifier Evaluator", BinaryClassifierEvaluator.LoadName, "BinaryClassifier", "Binary", "bin")]

// This is for deserialization from a binary model file.
// The per-instance row mapper is registered under SignatureLoadRowMapper with its LoaderSignature.
[assembly: LoadableClass(typeof(BinaryPerInstanceEvaluator), null, typeof(SignatureLoadRowMapper),
    "", BinaryPerInstanceEvaluator.LoaderSignature)]

// Exposes the Evaluate class as an entry-point module under the "Evaluators" name.
[assembly: LoadableClass(typeof(void), typeof(Evaluate), null, typeof(SignatureEntryPointModule), "Evaluators")]
 
namespace Microsoft.ML.Data
{
    [BestFriend]
    internal sealed class BinaryClassifierEvaluator : RowToRowEvaluatorBase<BinaryClassifierEvaluator.Aggregator>
    {
        /// <summary>
        /// User-configurable settings for <see cref="BinaryClassifierEvaluator"/>.
        /// </summary>
        public sealed class Arguments
        {
            /// <summary>
            /// Decision threshold: an instance is predicted positive when the thresholded value
            /// (raw score or probability, see <see cref="UseRawScoreThreshold"/>) is strictly greater than this.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Probability value for classification thresholding")]
            public float Threshold;

            /// <summary>
            /// When true, <see cref="Threshold"/> is compared with the raw score; otherwise with the
            /// probability column, which must then be present in the schema.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "Use raw score value instead of probability for classification thresholding", ShortName = "useRawScore")]
            public bool UseRawScoreThreshold = true;

            /// <summary>
            /// Reservoir size for p/r curve samples; 0 disables the p/r curve. Must be non-negative.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for p/r curve generation. Specify 0 for no p/r curve generation", ShortName = "numpr")]
            public int NumRocExamples;

            /// <summary>
            /// Reservoir size for AUC; -1 means the whole dataset. Must be at least -1.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for AUC calculation. If 0, AUC is not computed. If -1, the whole dataset is used", ShortName = "numauc")]
            public int MaxAucExamples = -1;

            /// <summary>
            /// Reservoir size for AUPRC; 0 disables AUPRC. Must be non-negative.
            /// </summary>
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for AUPRC calculation. Specify 0 for no AUPRC calculation", ShortName = "numauprc")]
            public int NumAuPrcExamples = 100000;
        }
 
        // Load name used by the LoadableClass registrations for this evaluator.
        public const string LoadName = "BinaryClassifierEvaluator";

        // Overall metrics.
        // These strings are the column names of the overall-metrics data view built in
        // GetAggregatorConsolidationFuncs, and the display values attached to the Metrics enum.
        public const string Accuracy = "Accuracy";
        public const string PosPrecName = "Positive precision";
        public const string PosRecallName = "Positive recall";
        public const string NegPrecName = "Negative precision";
        public const string NegRecallName = "Negative recall";
        public const string Auc = "AUC";
        public const string LogLoss = "Log-loss";
        public const string LogLossReduction = "Log-loss reduction";
        public const string Entropy = "Test-set entropy (prior Log-Loss/instance)";
        public const string F1 = "F1 Score";
        public const string AuPrc = "AUPRC";
 
        /// <summary>
        /// Identifies the overall metrics produced by this evaluator. Each member is tagged, via
        /// EnumValueDisplay, with the metric-name constant of the same name.
        /// Test-set entropy is reported as a metric column but has no member here.
        /// </summary>
        /// <remarks>
        /// Members use implicit values, so their declaration order determines their numeric values.
        /// </remarks>
        public enum Metrics
        {
            [EnumValueDisplay(BinaryClassifierEvaluator.Accuracy)]
            Accuracy,
            [EnumValueDisplay(BinaryClassifierEvaluator.PosPrecName)]
            PosPrecName,
            [EnumValueDisplay(BinaryClassifierEvaluator.PosRecallName)]
            PosRecallName,
            [EnumValueDisplay(BinaryClassifierEvaluator.NegPrecName)]
            NegPrecName,
            [EnumValueDisplay(BinaryClassifierEvaluator.NegRecallName)]
            NegRecallName,
            [EnumValueDisplay(BinaryClassifierEvaluator.Auc)]
            Auc,
            [EnumValueDisplay(BinaryClassifierEvaluator.LogLoss)]
            LogLoss,
            [EnumValueDisplay(BinaryClassifierEvaluator.LogLossReduction)]
            LogLossReduction,
            [EnumValueDisplay(BinaryClassifierEvaluator.F1)]
            F1,
            [EnumValueDisplay(BinaryClassifierEvaluator.AuPrc)]
            AuPrc,
        }
 
        /// <summary>
        /// Binary classification evaluator outputs a data view with this name, which contains the p/r data.
        /// It contains the columns listed below (threshold, precision, recall, and false positive rate), and in
        /// case data also contains a weight column, it also contains "Weighted "-prefixed columns holding the
        /// weighted precision, recall, and false positive rate. When the evaluation is stratified, it additionally
        /// starts with the stratification column and value columns.
        /// </summary>
        public const string PrCurve = "PrCurve";

        // Column names for the p/r data view.
        // Threshold holds the sampled scores; the other three hold the curve values at each of those scores.
        public const string Precision = "Precision";
        public const string Recall = "Recall";
        public const string FalsePositiveRate = "FPR";
        public const string Threshold = "Threshold";
 
        // Settings copied from Arguments in the constructor and forwarded to every Aggregator.
        // Decision threshold (Arguments.Threshold).
        private readonly Single _threshold;
        // Threshold the raw score (true) or the probability (false).
        private readonly bool _useRaw;
        // P/r curve reservoir size (Arguments.NumRocExamples); 0 disables the curve.
        private readonly int _prCount;
        // AUC reservoir size (Arguments.MaxAucExamples); -1 means the whole dataset.
        private readonly int _aucCount;
        // AUPRC reservoir size (Arguments.NumAuPrcExamples); 0 disables AUPRC.
        private readonly int _auPrcCount;
 
        /// <summary>
        /// Creates the evaluator after validating the reservoir sizes in <paramref name="args"/>.
        /// </summary>
        public BinaryClassifierEvaluator(IHostEnvironment env, Arguments args)
            : base(env, LoadName)
        {
            // Argument checks are reported through the NotSensitive() view of the host.
            var argHost = Host.NotSensitive();
            argHost.CheckValue(args, nameof(args));
            argHost.CheckUserArg(args.MaxAucExamples >= -1, nameof(args.MaxAucExamples), "Must be at least -1");
            argHost.CheckUserArg(args.NumRocExamples >= 0, nameof(args.NumRocExamples), "Must be non-negative");
            argHost.CheckUserArg(args.NumAuPrcExamples >= 0, nameof(args.NumAuPrcExamples), "Must be non-negative");

            // Thresholding settings.
            _threshold = args.Threshold;
            _useRaw = args.UseRawScoreThreshold;

            // Reservoir sizes for the sampled metrics.
            _aucCount = args.MaxAucExamples;
            _auPrcCount = args.NumAuPrcExamples;
            _prCount = args.NumRocExamples;
        }
 
        // Verifies that the score is a Single scalar and that the label is Single, Double, Boolean,
        // or a key type with exactly two values; throws a schema-mismatch exception otherwise.
        private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
        {
            var scoreCol = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
            var schemaHost = Host.SchemaSensitive();

            var scoreType = scoreCol.Type;
            if (scoreType != NumberDataViewType.Single)
                throw schemaHost.ExceptSchemaMismatch(nameof(schema), "score", scoreCol.Name, "Single", scoreType.ToString());

            schemaHost.Check(schema.Label.HasValue, "Could not find the label column");
            var labelCol = schema.Label.Value;
            var labelType = labelCol.Type;
            bool labelTypeSupported =
                labelType == NumberDataViewType.Single
                || labelType == NumberDataViewType.Double
                || labelType == BooleanDataViewType.Instance
                || labelType.GetKeyCount() == 2;
            if (!labelTypeSupported)
                throw schemaHost.ExceptSchemaMismatch(nameof(schema), "label", labelCol.Name, "Single, Double, Boolean, or a Key with cardinality 2", labelType.ToString());
        }
 
        // Validates the optional probability column: at most one, of type Single. When it is absent,
        // thresholding must be done on the raw score, so a probability-threshold configuration is rejected.
        private protected override void CheckCustomColumnTypesCore(RoleMappedSchema schema)
        {
            var probCols = schema.GetColumns(AnnotationUtils.Const.ScoreValueKind.Probability);
            var schemaHost = Host.SchemaSensitive();

            if (probCols == null)
            {
                if (!_useRaw)
                {
                    throw schemaHost.ExceptParam(nameof(schema),
                        "Cannot compute the predicted label from the probability column because it does not exist");
                }
                return;
            }

            schemaHost.CheckParam(probCols.Count == 1, nameof(schema), "Cannot have multiple probability columns");
            var probCol = probCols[0];
            var probType = probCol.Type;
            if (probType != NumberDataViewType.Single)
                throw schemaHost.ExceptSchemaMismatch(nameof(schema), "probability", probCol.Name, "Single", probType.ToString());
        }
 
        // Extends the base set of active columns with the probability column, when one is present.
        private protected override Func<int, bool> GetActiveColsCore(RoleMappedSchema schema)
        {
            var baseActive = base.GetActiveColsCore(schema);
            var probCols = schema.GetColumns(AnnotationUtils.Const.ScoreValueKind.Probability);
            Host.Assert(probCols == null || probCols.Count == 1);

            // Column indices are non-negative, so -1 never matches when there is no probability column.
            int probIndex = Utils.Size(probCols) > 0 ? probCols[0].Index : -1;
            return col => col == probIndex || baseActive(col);
        }
 
        // Creates an aggregator for one stratum, configured with this evaluator's settings; it is weighted
        // exactly when the schema has a weight column.
        private protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
        {
            var labelClassNames = GetClassNames(schema);
            return new Aggregator(
                env: Host,
                classNames: labelClassNames,
                weighted: schema.Weight != null,
                aucReservoirSize: _aucCount,
                auPrcReservoirSize: _auPrcCount,
                threshold: _threshold,
                useRaw: _useRaw,
                prCount: _prCount,
                stratName: stratName);
        }
 
        // Returns the two class names used as confusion-matrix slot names: the label's text key values
        // when it is a key column that has them, otherwise "positive" and "negative".
        private ReadOnlyMemory<char>[] GetClassNames(RoleMappedSchema schema)
        {
            var labelCol = schema.Label.Value;

            bool hasTextKeyValues = labelCol.Type is KeyDataViewType
                && labelCol.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type is VectorDataViewType vecType
                && vecType.Size > 0
                && vecType.ItemType == TextDataViewType.Instance;

            if (!hasTextKeyValues)
                return new[] { "positive".AsMemory(), "negative".AsMemory() };

            var keyValues = default(VBuffer<ReadOnlyMemory<char>>);
            labelCol.GetKeyValues(ref keyValues);

            // Key labels are only accepted with cardinality 2, so two names are copied out.
            var classNames = new ReadOnlyMemory<char>[2];
            keyValues.CopyTo(classNames);
            return classNames;
        }
 
        /// <summary>
        /// Builds the <see cref="BinaryPerInstanceEvaluator"/> row mapper from the score column, the optional
        /// probability column and the label column of <paramref name="schema"/>, using this evaluator's
        /// threshold settings.
        /// </summary>
        private protected override IRowMapper CreatePerInstanceRowMapper(RoleMappedSchema schema)
        {
            // Validate through this evaluator's host, as the other schema checks in this class do, so that
            // failures carry the host's context rather than going through the static Contracts helpers.
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Could not find the label column");
            var scoreInfo = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);

            // The probability column is optional; only the first one is used
            // (CheckCustomColumnTypesCore rejects schemas with several).
            var probInfos = schema.GetColumns(AnnotationUtils.Const.ScoreValueKind.Probability);
            var probCol = Utils.Size(probInfos) > 0 ? probInfos[0].Name : null;
            return new BinaryPerInstanceEvaluator(Host, schema.Schema, scoreInfo.Name, probCol, schema.Label.Value.Name, _threshold, _useRaw);
        }
 
        /// <summary>
        /// Describes the overall metric columns: each short name is paired with its metric-name constant,
        /// and log-loss is the only one flagged with <c>Objective.Minimize</c>.
        /// </summary>
        public override IEnumerable<MetricColumn> GetOverallMetricColumns()
        {
            var metricColumns = new[]
            {
                new MetricColumn("Accuracy", Accuracy),
                new MetricColumn("PosPrec", PosPrecName),
                new MetricColumn("PosRecall", PosRecallName),
                new MetricColumn("NegPrec", NegPrecName),
                new MetricColumn("NegRecall", NegRecallName),
                new MetricColumn("AUC", Auc),
                new MetricColumn("LogLoss", LogLoss, MetricColumn.Objective.Minimize),
                new MetricColumn("LogLossReduction", LogLossReduction),
                new MetricColumn("Entropy", Entropy),
                new MetricColumn("F1", F1),
                new MetricColumn("AUPRC", AuPrc),
            };

            foreach (var metricColumn in metricColumns)
                yield return metricColumn;
        }
 
        // Builds the two callbacks used to merge per-stratum aggregators into output data views:
        // - addAgg finishes one aggregator and appends to the lists below: one row of overall metrics
        //   (plus a second, weighted row when weighted), its 2x2 confusion matrix, and any p/r curve points;
        // - consolidate turns those lists into the OverallMetrics and ConfusionMatrix data views, plus a
        //   PrCurve data view when at least one p/r point was collected.
        // All aggregators must agree with 'aggregator' on weighting and on whether AUPRC is computed.
        private protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, AggregatorDictionaryBase[] dictionaries,
            out Action<uint, ReadOnlyMemory<char>, Aggregator> addAgg, out Func<Dictionary<string, IDataView>> consolidate)
        {
            // Overall metrics: one entry per (stratum, weighted/unweighted) row; the lists stay aligned by index.
            var stratCol = new List<uint>();
            var stratVal = new List<ReadOnlyMemory<char>>();
            var isWeighted = new List<bool>();
            var auc = new List<Double>();
            var accuracy = new List<Double>();
            var posPrec = new List<Double>();
            var posRecall = new List<Double>();
            var negPrec = new List<Double>();
            var negRecall = new List<Double>();
            var logLoss = new List<Double>();
            var logLossRed = new List<Double>();
            var entropy = new List<Double>();
            var f1 = new List<Double>();
            var auprc = new List<Double>();

            // Confusion matrix: two rows per stratum, [TP, FN] (label > 0) then [FP, TN]; weighted counts
            // go into 'weights' in the same layout.
            var counts = new List<Double[]>();
            var weights = new List<Double[]>();
            var confStratCol = new List<uint>();
            var confStratVal = new List<ReadOnlyMemory<char>>();

            // P/r curve points, concatenated across strata.
            var scores = new List<Single>();
            var precision = new List<Double>();
            var recall = new List<Double>();
            var fpr = new List<Double>();
            var weightedPrecision = new List<Double>();
            var weightedRecall = new List<Double>();
            var weightedFpr = new List<Double>();
            var prStratCol = new List<uint>();
            var prStratVal = new List<ReadOnlyMemory<char>>();

            bool hasStrats = Utils.Size(dictionaries) > 0;
            bool hasWeight = aggregator.Weighted;

            addAgg =
                (stratColKey, stratColVal, agg) =>
                {
                    Host.Check(agg.Weighted == hasWeight, "All aggregators must either be weighted or unweighted");
                    Host.Check((agg.AuPrcAggregator == null) == (aggregator.AuPrcAggregator == null),
                        "All aggregators must either compute AUPRC or not compute AUPRC");

                    agg.Finish();
                    // Unweighted overall-metrics row.
                    stratCol.Add(stratColKey);
                    stratVal.Add(stratColVal);
                    isWeighted.Add(false);
                    auc.Add(agg.UnweightedAuc);
                    accuracy.Add(agg.UnweightedCounters.Acc);
                    posPrec.Add(agg.UnweightedCounters.PrecisionPos);
                    posRecall.Add(agg.UnweightedCounters.RecallPos);
                    negPrec.Add(agg.UnweightedCounters.PrecisionNeg);
                    negRecall.Add(agg.UnweightedCounters.RecallNeg);
                    logLoss.Add(agg.UnweightedCounters.LogLoss);
                    logLossRed.Add(agg.UnweightedCounters.LogLossReduction);
                    entropy.Add(agg.UnweightedCounters.Entropy);
                    f1.Add(agg.UnweightedCounters.F1);
                    if (agg.AuPrcAggregator != null)
                        auprc.Add(agg.UnweightedAuPrc);

                    // Two confusion-matrix rows for this stratum.
                    confStratCol.AddRange(new[] { stratColKey, stratColKey });
                    confStratVal.AddRange(new[] { stratColVal, stratColVal });
                    counts.Add(new[] { agg.UnweightedCounters.NumTruePos, agg.UnweightedCounters.NumFalseNeg });
                    counts.Add(new[] { agg.UnweightedCounters.NumFalsePos, agg.UnweightedCounters.NumTrueNeg });
                    // Scores is non-null only when the aggregator was asked for a p/r curve.
                    if (agg.Scores != null)
                    {
                        Host.AssertValue(agg.Precision);
                        Host.AssertValue(agg.Recall);
                        Host.AssertValue(agg.FalsePositiveRate);

                        scores.AddRange(agg.Scores);
                        precision.AddRange(agg.Precision);
                        recall.AddRange(agg.Recall);
                        fpr.AddRange(agg.FalsePositiveRate);

                        if (hasStrats)
                        {
                            // Tag every p/r point with this stratum.
                            prStratCol.AddRange(agg.Scores.Select(x => stratColKey));
                            prStratVal.AddRange(agg.Scores.Select(x => stratColVal));
                        }
                    }
                    if (agg.Weighted)
                    {
                        // Weighted overall-metrics row, confusion-matrix weights and weighted p/r values.
                        stratCol.Add(stratColKey);
                        stratVal.Add(stratColVal);
                        isWeighted.Add(true);
                        auc.Add(agg.WeightedAuc);
                        accuracy.Add(agg.WeightedCounters.Acc);
                        posPrec.Add(agg.WeightedCounters.PrecisionPos);
                        posRecall.Add(agg.WeightedCounters.RecallPos);
                        negPrec.Add(agg.WeightedCounters.PrecisionNeg);
                        negRecall.Add(agg.WeightedCounters.RecallNeg);
                        logLoss.Add(agg.WeightedCounters.LogLoss);
                        logLossRed.Add(agg.WeightedCounters.LogLossReduction);
                        entropy.Add(agg.WeightedCounters.Entropy);
                        f1.Add(agg.WeightedCounters.F1);
                        if (agg.AuPrcAggregator != null)
                            auprc.Add(agg.WeightedAuPrc);
                        weights.Add(new[] { agg.WeightedCounters.NumTruePos, agg.WeightedCounters.NumFalseNeg });
                        weights.Add(new[] { agg.WeightedCounters.NumFalsePos, agg.WeightedCounters.NumTrueNeg });

                        if (agg.Scores != null)
                        {
                            Host.AssertValue(agg.WeightedPrecision);
                            Host.AssertValue(agg.WeightedRecall);
                            Host.AssertValue(agg.WeightedFalsePositiveRate);

                            weightedPrecision.AddRange(agg.WeightedPrecision);
                            weightedRecall.AddRange(agg.WeightedRecall);
                            weightedFpr.AddRange(agg.WeightedFalsePositiveRate);
                        }
                    }
                };

            consolidate =
                () =>
                {
                    // Overall metrics: optional stratification and IsWeighted columns, then one column per metric.
                    var overallDvBldr = new ArrayDataViewBuilder(Host);
                    if (hasStrats)
                    {
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, stratCol.ToArray());
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, stratVal.ToArray());
                    }
                    if (hasWeight)
                        overallDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BooleanDataViewType.Instance, isWeighted.ToArray());
                    overallDvBldr.AddColumn(Auc, NumberDataViewType.Double, auc.ToArray());
                    overallDvBldr.AddColumn(Accuracy, NumberDataViewType.Double, accuracy.ToArray());
                    overallDvBldr.AddColumn(PosPrecName, NumberDataViewType.Double, posPrec.ToArray());
                    overallDvBldr.AddColumn(PosRecallName, NumberDataViewType.Double, posRecall.ToArray());
                    overallDvBldr.AddColumn(NegPrecName, NumberDataViewType.Double, negPrec.ToArray());
                    overallDvBldr.AddColumn(NegRecallName, NumberDataViewType.Double, negRecall.ToArray());
                    overallDvBldr.AddColumn(LogLoss, NumberDataViewType.Double, logLoss.ToArray());
                    overallDvBldr.AddColumn(LogLossReduction, NumberDataViewType.Double, logLossRed.ToArray());
                    overallDvBldr.AddColumn(Entropy, NumberDataViewType.Double, entropy.ToArray());
                    overallDvBldr.AddColumn(F1, NumberDataViewType.Double, f1.ToArray());
                    if (aggregator.AuPrcAggregator != null)
                        overallDvBldr.AddColumn(AuPrc, NumberDataViewType.Double, auprc.ToArray());

                    // Confusion matrix: vector columns whose two slots are named after the class names.
                    var confDvBldr = new ArrayDataViewBuilder(Host);
                    if (hasStrats)
                    {
                        confDvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, confStratCol.ToArray());
                        confDvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, confStratVal.ToArray());
                    }
                    ValueGetter<VBuffer<ReadOnlyMemory<char>>> getSlotNames =
                        (ref VBuffer<ReadOnlyMemory<char>> dst) =>
                            dst = new VBuffer<ReadOnlyMemory<char>>(aggregator.ClassNames.Length, aggregator.ClassNames);
                    confDvBldr.AddColumn(MetricKinds.ColumnNames.Count, getSlotNames, NumberDataViewType.Double, counts.ToArray());

                    if (hasWeight)
                        confDvBldr.AddColumn(MetricKinds.ColumnNames.Weight, getSlotNames, NumberDataViewType.Double, weights.ToArray());

                    var result = new Dictionary<string, IDataView>();
                    result.Add(MetricKinds.OverallMetrics, overallDvBldr.GetDataView());
                    result.Add(MetricKinds.ConfusionMatrix, confDvBldr.GetDataView());

                    // The p/r curve view is only produced when some aggregator collected points.
                    if (scores.Count > 0)
                    {
                        var dvBldr = new ArrayDataViewBuilder(Host);
                        if (hasStrats)
                        {
                            dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, GetKeyValueGetter(dictionaries), (ulong)dictionaries.Length, prStratCol.ToArray());
                            dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextDataViewType.Instance, prStratVal.ToArray());
                        }
                        dvBldr.AddColumn(Threshold, NumberDataViewType.Single, scores.ToArray());
                        dvBldr.AddColumn(Precision, NumberDataViewType.Double, precision.ToArray());
                        dvBldr.AddColumn(Recall, NumberDataViewType.Double, recall.ToArray());
                        dvBldr.AddColumn(FalsePositiveRate, NumberDataViewType.Double, fpr.ToArray());
                        if (weightedPrecision.Count > 0)
                        {
                            dvBldr.AddColumn("Weighted " + Precision, NumberDataViewType.Double, weightedPrecision.ToArray());
                            dvBldr.AddColumn("Weighted " + Recall, NumberDataViewType.Double, weightedRecall.ToArray());
                            dvBldr.AddColumn("Weighted " + FalsePositiveRate, NumberDataViewType.Double, weightedFpr.ToArray());
                        }
                        result.Add(PrCurve, dvBldr.GetDataView());
                    }
                    return result;
                };
        }
 
        public sealed class Aggregator : AggregatorBase
        {
            /// <summary>
            /// Running confusion-matrix and log-loss totals for one evaluation, plus the metrics derived from them.
            /// Each row is added with a weight: 1 for the unweighted counters, the row weight otherwise.
            /// </summary>
            public sealed class Counters
            {
                private readonly bool _useRaw;
                private readonly Single _threshold;

                // Confusion-matrix cells as sums of row weights. "Pos" means label > 0.
                public Double NumTruePos;
                public Double NumTrueNeg;
                public Double NumFalsePos;
                public Double NumFalseNeg;

                // Log-loss bookkeeping. The two weight totals only count rows whose probability is not NaN.
                private Double _logLossPosWeight;
                private Double _logLossNegWeight;
                private Double _logLossSum;

                /// <summary>
                /// Creates empty counters. Rows are predicted positive when the raw score (if <paramref name="useRaw"/>)
                /// or the probability is strictly greater than <paramref name="threshold"/>.
                /// </summary>
                public Counters(bool useRaw, Single threshold)
                {
                    _useRaw = useRaw;
                    _threshold = threshold;
                }

                // Sum of all four cells, in a fixed summation order so results stay bit-identical.
                private Double TotalWeight => NumTruePos + NumTrueNeg + NumFalseNeg + NumFalsePos;

                // Total weight of rows that contributed to the log-loss priors.
                private Double LogLossWeight => _logLossPosWeight + _logLossNegWeight;

                // numerator / denominator, or 0 when the denominator is not positive.
                private static Double RatioOrZero(Double numerator, Double denominator) =>
                    denominator > 0 ? numerator / denominator : 0;

                /// <summary>Fraction of weight classified correctly; NaN when nothing was counted.</summary>
                public Double Acc => (NumTrueNeg + NumTruePos) / TotalWeight;

                /// <summary>TP / (TP + FN), or 0 without positive-label weight.</summary>
                public Double RecallPos => RatioOrZero(NumTruePos, NumTruePos + NumFalseNeg);

                /// <summary>TP / (TP + FP), or 0 without positive predictions.</summary>
                public Double PrecisionPos => RatioOrZero(NumTruePos, NumTruePos + NumFalsePos);

                /// <summary>TN / (TN + FP), or 0 without negative-label weight.</summary>
                public Double RecallNeg => RatioOrZero(NumTrueNeg, NumTrueNeg + NumFalsePos);

                /// <summary>TN / (TN + FN), or 0 without negative predictions.</summary>
                public Double PrecisionNeg => RatioOrZero(NumTrueNeg, NumTrueNeg + NumFalseNeg);

                /// <summary><c>MathUtils.Entropy</c> of the positive-label fraction of the total weight.</summary>
                public Double Entropy => MathUtils.Entropy((NumTruePos + NumFalseNeg) / TotalWeight);

                /// <summary>
                /// Mean log-loss over rows counted in the priors; NaN once any NaN log-loss was added,
                /// 0 when no row was counted.
                /// </summary>
                public Double LogLoss => Double.IsNaN(_logLossSum) ? Double.NaN : RatioOrZero(_logLossSum, LogLossWeight);

                /// <summary>
                /// Relative improvement of the mean log-loss over the prior log-loss
                /// (<c>MathUtils.Entropy</c> of the positive fraction); 0 when no row was counted.
                /// </summary>
                public Double LogLossReduction
                {
                    get
                    {
                        var weight = LogLossWeight;
                        if (weight == 0)
                            return 0;

                        var meanLogLoss = _logLossSum / weight;
                        var priorLogLoss = MathUtils.Entropy(_logLossPosWeight / weight);
                        return (priorLogLoss - meanLogLoss) / priorLogLoss;
                    }
                }

                /// <summary>Harmonic mean of positive precision and recall; 0 when both are 0.</summary>
                public Double F1
                {
                    get
                    {
                        var p = PrecisionPos;
                        var r = RecallPos;
                        var sum = p + r;
                        return sum == 0 ? 0 : 2 * p * r / sum;
                    }
                }

                /// <summary>
                /// Adds one row with the given <paramref name="weight"/> to the confusion matrix and log-loss totals.
                /// <paramref name="logloss"/> is always accumulated, while the log-loss priors only count rows
                /// whose <paramref name="prob"/> is not NaN.
                /// </summary>
                public void Update(Single score, Single prob, Single label, Double logloss, Single weight)
                {
                    bool actualPositive = label > 0;
                    bool predictedPositive = _useRaw ? score > _threshold : prob > _threshold;

                    if (actualPositive && predictedPositive)
                    {
                        NumTruePos += weight;
                    }
                    else if (actualPositive)
                    {
                        NumFalseNeg += weight;
                    }
                    else if (predictedPositive)
                    {
                        NumFalsePos += weight;
                    }
                    else
                    {
                        NumTrueNeg += weight;
                    }

                    if (!Single.IsNaN(prob))
                    {
                        if (actualPositive)
                            _logLossPosWeight += weight;
                        else
                            _logLossNegWeight += weight;
                    }

                    _logLossSum += logloss * weight;
                }
            }
 
            /// <summary>
            /// One p/r-curve reservoir sample: the score, label and weight of the row being processed
            /// when the sample was taken.
            /// </summary>
            private struct RocInfo
            {
                public float Score;
                public float Label;
                public float Weight;
            }
 
            // Reservoir of RocInfo samples for the p/r curve; only created when a p/r curve was requested (prCount > 0).
            private readonly ReservoirSamplerWithoutReplacement<RocInfo> _prCurveReservoir;
            // P/r curve outputs, allocated only when a p/r curve was requested (the weighted lists additionally
            // require a weight column); they are read by the consolidation code after Finish().
            // NOTE(review): presumably filled by FinishOtherMetrics, which is outside this view — confirm.
            public readonly List<Single> Scores;
            public readonly List<Double> Precision;
            public readonly List<Double> Recall;
            public readonly List<Double> FalsePositiveRate;
            public readonly List<Double> WeightedPrecision;
            public readonly List<Double> WeightedRecall;
            public readonly List<Double> WeightedFalsePositiveRate;

            // AUPRC aggregator (null when AUPRC is disabled) and its results, set by Finish().
            internal readonly AuPrcAggregatorBase AuPrcAggregator;
            public double WeightedAuPrc;
            public double UnweightedAuPrc;

            // AUC aggregator (always created) and its results, set by Finish().
            private readonly AucAggregatorBase _aucAggregator;
            public double WeightedAuc;
            public double UnweightedAuc;

            // Confusion-matrix/log-loss counters; WeightedCounters is null for unweighted aggregators.
            public readonly Counters UnweightedCounters;
            public readonly Counters WeightedCounters;

            // True when the schema has a weight column.
            public readonly bool Weighted;

            // Column getters bound by InitializeNextPass; _weightGetter stays null when unweighted.
            private ValueGetter<Single> _labelGetter;
            private ValueGetter<Single> _scoreGetter;
            private ValueGetter<Single> _weightGetter;
            private ValueGetter<Single> _probGetter;
            // Values of the row currently being processed; also captured by the p/r curve sample getter.
            private Single _score;
            private Single _label;
            private Single _weight;

            // The two class names used as confusion-matrix slot names.
            public readonly ReadOnlyMemory<char>[] ClassNames;
 
            /// <summary>
            /// Creates the per-stratum aggregator: unweighted counters always, weighted counters when
            /// <paramref name="weighted"/>, an AUC aggregator always, an AUPRC aggregator when
            /// <paramref name="auPrcReservoirSize"/> is positive, and p/r curve state when <paramref name="prCount"/> is positive.
            /// </summary>
            public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, bool weighted, int aucReservoirSize,
                int auPrcReservoirSize, Single threshold, bool useRaw, int prCount, string stratName)
                : base(env, stratName)
            {
                Host.Assert(Utils.Size(classNames) == 2);
                Host.Assert(aucReservoirSize >= -1);
                Host.Assert(prCount >= 0);
                Host.Assert(auPrcReservoirSize >= 0);
                Host.Assert(useRaw || 0 <= threshold && threshold <= 1);

                ClassNames = classNames;
                Weighted = weighted;
                UnweightedCounters = new Counters(useRaw, threshold);
                WeightedCounters = weighted ? new Counters(useRaw, threshold) : null;

                // The AUC aggregator is created before the AUPRC one, and both before the p/r reservoir.
                _aucAggregator = weighted
                    ? (AucAggregatorBase)new WeightedAucAggregator(Host.Rand, aucReservoirSize)
                    : new UnweightedAucAggregator(Host.Rand, aucReservoirSize);
                if (auPrcReservoirSize > 0)
                {
                    AuPrcAggregator = weighted
                        ? (AuPrcAggregatorBase)new WeightedAuPrcAggregator(Host.Rand, auPrcReservoirSize)
                        : new UnweightedAuPrcAggregator(Host.Rand, auPrcReservoirSize);
                }

                // The remaining state is only needed when a p/r curve was requested.
                if (prCount <= 0)
                    return;

                // Snapshots the row currently being processed into a reservoir sample.
                ValueGetter<RocInfo> sampleCurrentRow =
                    (ref RocInfo dst) =>
                    {
                        dst.Label = _label;
                        dst.Score = _score;
                        dst.Weight = _weight;
                    };
                _prCurveReservoir = new ReservoirSamplerWithoutReplacement<RocInfo>(Host.Rand, prCount, sampleCurrentRow);
                Scores = new List<Single>();
                Precision = new List<Double>();
                Recall = new List<Double>();
                FalsePositiveRate = new List<Double>();
                if (weighted)
                {
                    WeightedPrecision = new List<Double>();
                    WeightedRecall = new List<Double>();
                    WeightedFalsePositiveRate = new List<Double>();
                }
            }
 
            // Binds the label, score, probability and (when weighted) weight getters to the given row.
            internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema schema)
            {
                Host.Assert(schema.Label.HasValue);
                Host.Assert(PassNum < 1);

                // Score and label are mandatory.
                var scoreCol = schema.GetUniqueColumn(AnnotationUtils.Const.ScoreValueKind.Score);
                _labelGetter = RowCursorUtils.GetLabelGetter(row, schema.Label.Value.Index);
                _scoreGetter = row.GetGetter<Single>(scoreCol);
                Host.AssertValue(_labelGetter);
                Host.AssertValue(_scoreGetter);

                // Probability is optional; without it every row reads NaN, so ProcessRow treats log-loss as NaN.
                var probCols = schema.GetColumns(new RoleMappedSchema.ColumnRole(AnnotationUtils.Const.ScoreValueKind.Probability));
                Host.Assert(probCols == null || probCols.Count == 1);
                if (probCols == null)
                    _probGetter = (ref Single value) => value = Single.NaN;
                else
                    _probGetter = row.GetGetter<Single>(probCols[0]);

                // A weight getter is only bound for weighted aggregators.
                Host.Assert((schema.Weight != null) == Weighted);
                if (!Weighted)
                    return;
                _weightGetter = row.GetGetter<Single>(schema.Weight.Value);
            }
 
            /// <summary>
            /// Accumulates the current row into the counters, the AUC/AUPRC aggregators and the p/r reservoir.
            /// Rows with a non-finite score or a NaN label are only counted (NumBadScores / NumUnlabeledInstances)
            /// and otherwise skipped; non-finite weights are counted in NumBadWeights and replaced by 1.
            /// </summary>
            public override void ProcessRow()
            {
                _labelGetter(ref _label);
                _scoreGetter(ref _score);
                if (!FloatUtils.IsFinite(_score))
                {
                    NumBadScores++;
                    return;
                }
                if (Single.IsNaN(_label))
                {
                    NumUnlabeledInstances++;
                    return;
                }

                Single prob = 0;
                _probGetter(ref prob);

                // Per-instance base-2 log-loss; NaN when no probability is available.
                Double logloss;
                if (!Single.IsNaN(prob))
                {
                    if (_label > 0)
                    {
                        // REVIEW: Should we bring back the option to use ln instead of log2?
                        logloss = -Math.Log(prob, 2);
                    }
                    else
                        logloss = -Math.Log(1.0 - prob, 2);
                }
                else
                    logloss = Double.NaN;

                UnweightedCounters.Update(_score, prob, _label, logloss, 1);

                Host.Assert((_weightGetter != null) == Weighted);
                if (_weightGetter != null)
                {
                    _weightGetter(ref _weight);
                    if (!FloatUtils.IsFinite(_weight))
                    {
                        NumBadWeights++;
                        _weight = 1;
                    }
                    _aucAggregator.ProcessRow(_label, _score, _weight);
                    WeightedCounters.Update(_score, prob, _label, logloss, _weight);
                }
                else
                {
                    // Without a weight column every row counts once. _weight is otherwise never assigned
                    // (it would stay 0), yet it is captured by the p/r curve sampler and passed to the
                    // AUPRC aggregator below, so give it the unit weight explicitly.
                    _weight = 1;
                    _aucAggregator.ProcessRow(_label, _score);
                }

                if (_prCurveReservoir != null)
                    _prCurveReservoir.Sample();
                if (AuPrcAggregator != null)
                    AuPrcAggregator.ProcessRow(_label, _score, _weight);
            }
 
            public void Finish()
            {
                Contracts.Assert(!IsActive());

                // Finalize the AUC aggregator before reading the AUC values out of it.
                var aucAggregator = _aucAggregator;
                aucAggregator.Finish();
                WeightedAuc = aucAggregator.ComputeWeightedAuc(out UnweightedAuc);

                // AUPRC is only computed when an AUPRC aggregator exists.
                var auPrcAggregator = AuPrcAggregator;
                if (auPrcAggregator != null)
                    WeightedAuPrc = auPrcAggregator.ComputeWeightedAuPrc(out UnweightedAuPrc);

                FinishOtherMetrics();
            }
 
            private void FinishOtherMetrics()
            {
                // The p/r curve lists are only built when a reservoir of sampled examples was kept.
                if (_prCurveReservoir == null)
                    return;

                ComputePrCurves();
            }
 
            /// <summary>
            /// Builds the precision / recall / false-positive-rate curves from the examples sampled into
            /// <c>_prCurveReservoir</c>. It fills <c>Scores</c>, <c>Precision</c>, <c>Recall</c> and
            /// <c>FalsePositiveRate</c>, plus the weighted lists when <c>Weighted</c> is set. Any previous
            /// contents of those lists are cleared first.
            /// </summary>
            /// <remarks>
            /// One point is emitted per distinct sampled score, in descending score order. The point for
            /// threshold t counts every sampled example whose score is greater than or equal to t.
            /// Recall and false positive rate are first accumulated as raw counts (or weight sums) and only
            /// normalized at the end. If the sample has no positives (or no negatives), that normalization
            /// divides by zero and the corresponding values become NaN.
            /// </remarks>
            private void ComputePrCurves()
            {
                Host.AssertValue(_prCurveReservoir);
                Host.AssertValue(Scores);
                Host.AssertValue(Precision);
                Host.AssertValue(Recall);
                Host.AssertValue(FalsePositiveRate);

                // Lock the reservoir, then read the sampled examples out of it.
                _prCurveReservoir.Lock();
                var prSample = _prCurveReservoir.GetSample();

                // Rebuild every curve from scratch.
                Scores.Clear();
                Precision.Clear();
                Recall.Clear();
                FalsePositiveRate.Clear();
                if (Weighted)
                {
                    Host.AssertValue(WeightedPrecision);
                    Host.AssertValue(WeightedRecall);
                    Host.AssertValue(WeightedFalsePositiveRate);

                    WeightedPrecision.Clear();
                    WeightedRecall.Clear();
                    WeightedFalsePositiveRate.Clear();
                }

                // Running counts (pos/neg) and weight sums (wpos/wneg) of the positive and negative
                // examples visited so far. Every visited example has score >= scoreCur.
                Double pos = 0;
                Double neg = 0;
                Double wpos = 0;
                Double wneg = 0;
                Single scoreCur = Single.PositiveInfinity;

                // Walk the sample from the highest score to the lowest. A sentinel with score -Infinity is
                // appended so that the last group of tied scores still gets flushed as a curve point.
                foreach (var point in prSample.OrderByDescending(x => x.Score)
                    .Concat(new[] { new RocInfo() { Score = Single.NegativeInfinity } }))
                {
                    // When the score drops below the current threshold, emit the point for scoreCur.
                    // That point covers every example visited so far. Nothing is emitted before the
                    // first example, because pos + neg is still 0 then.
                    if (point.Score < scoreCur)
                    {
                        if (pos + neg > 0)
                        {
                            Scores.Add(scoreCur);
                            Precision.Add(pos / (pos + neg));
                            // Raw counts for now; they are normalized after the loop.
                            Recall.Add(pos);
                            FalsePositiveRate.Add(neg);
                            if (Weighted)
                            {
                                WeightedPrecision.Add(wpos / (wpos + wneg));
                                WeightedRecall.Add(wpos);
                                WeightedFalsePositiveRate.Add(wneg);
                            }
                        }
                        scoreCur = point.Score;
                    }
                    // The -Infinity sentinel carries no example, so it is not counted.
                    if (Single.IsNegativeInfinity(point.Score))
                        continue;

                    if (point.Label > 0)
                        pos++;
                    else
                        neg++;
                    if (Weighted)
                    {
                        if (point.Label > 0)
                            wpos += point.Weight;
                        else
                            wneg += point.Weight;
                    }
                }

                // Normalize the counts into rates: recall = TP / total positives, FPR = FP / total negatives.
                for (int i = 0; i < Recall.Count; i++)
                {
                    Recall[i] /= pos;
                    FalsePositiveRate[i] /= neg;
                }
                if (Weighted)
                {
                    for (int i = 0; i < WeightedRecall.Count; i++)
                    {
                        WeightedRecall[i] /= wpos;
                        WeightedFalsePositiveRate[i] /= wneg;
                    }
                }
            }
        }
 
        /// <summary>
        /// Evaluates scored binary classification data, including the probability-based (calibrated) metrics.
        /// </summary>
        /// <param name="data">The scored data.</param>
        /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
        /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
        /// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
        /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
        /// <returns>The evaluation results for these calibrated outputs.</returns>
        public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string label, string score, string probability, string predictedLabel)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));
            Host.CheckNonEmpty(probability, nameof(probability));
            Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

            // Bind every user-supplied column name to its evaluation role.
            var labelRole = RoleMappedSchema.ColumnRole.Label.Bind(label);
            var scoreRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score);
            var probabilityRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probability);
            var predictedLabelRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel);
            var roles = new RoleMappedData(data, opt: false, labelRole, scoreRole, probabilityRole, predictedLabelRole);

            var metricViews = ((IEvaluator)this).Evaluate(roles);
            Host.Assert(metricViews.ContainsKey(MetricKinds.OverallMetrics));
            var overallView = metricViews[MetricKinds.OverallMetrics];
            var confusionView = metricViews[MetricKinds.ConfusionMatrix];

            // The overall metrics view is expected to hold exactly one row.
            CalibratedBinaryClassificationMetrics metrics;
            using (var overallCursor = overallView.GetRowCursorForAllColumns())
            {
                bool hasRow = overallCursor.MoveNext();
                Host.Assert(hasRow);
                metrics = new CalibratedBinaryClassificationMetrics(Host, overallCursor, confusionView);
                bool hasExtraRow = overallCursor.MoveNext();
                Host.Assert(!hasExtraRow);
            }
            return metrics;
        }
 
        /// <summary>
        /// Evaluates scored binary classification data and generates precision recall curve data.
        /// </summary>
        /// <param name="data">The scored data.</param>
        /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
        /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
        /// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
        /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
        /// <param name="prCurve">The generated precision recall curve data. Up to 100000 samples are used for p/r curve generation.</param>
        /// <returns>The evaluation results for these calibrated outputs.</returns>
        /// <remarks>
        /// Throws a host exception with a descriptive message if the evaluator produced no precision recall
        /// curve data, for example because p/r curve sampling is disabled.
        /// </remarks>
        public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
            IDataView data,
            string label,
            string score,
            string probability,
            string predictedLabel,
            out List<BinaryPrecisionRecallDataPoint> prCurve)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));
            Host.CheckNonEmpty(probability, nameof(probability));
            Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

            var roles = new RoleMappedData(data, opt: false,
                RoleMappedSchema.ColumnRole.Label.Bind(label),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probability),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));

            var resultDict = ((IEvaluator)this).Evaluate(roles);

            // The p/r curve view is optional in the evaluator's output; it is missing when p/r curve sampling
            // is disabled. Check explicitly instead of relying on a debug assertion, so that callers get a
            // meaningful error rather than a bare KeyNotFoundException from the dictionary indexer.
            if (!resultDict.TryGetValue(MetricKinds.PrCurve, out var prCurveView))
                throw Host.Except("Precision-recall curve data was not produced by the evaluator. Make sure p/r curve generation is enabled (NumRocExamples > 0).");

            Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
            var overall = resultDict[MetricKinds.OverallMetrics];
            var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

            // Materialize one data point per row of the p/r curve view.
            var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
            using (var cursor = prCurveView.GetRowCursorForAllColumns())
            {
                GetPrecisionRecallDataPointGetters(prCurveView, cursor,
                    out ValueGetter<float> thresholdGetter,
                    out ValueGetter<double> precisionGetter,
                    out ValueGetter<double> recallGetter,
                    out ValueGetter<double> fprGetter);

                while (cursor.MoveNext())
                {
                    prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
                }
            }
            prCurve = prCurveResult;

            // The overall metrics view is expected to hold exactly one row.
            CalibratedBinaryClassificationMetrics result;
            using (var cursor = overall.GetRowCursorForAllColumns())
            {
                var moved = cursor.MoveNext();
                Host.Assert(moved);
                result = new CalibratedBinaryClassificationMetrics(Host, cursor, confusionMatrix);
                moved = cursor.MoveNext();
                Host.Assert(!moved);
            }

            return result;
        }
 
        /// <summary>
        /// Opens getters on <paramref name="cursor"/> for the threshold, precision, recall and false positive rate
        /// columns of the p/r curve view <paramref name="prCurveView"/>.
        /// </summary>
        private void GetPrecisionRecallDataPointGetters(IDataView prCurveView,
            DataViewRowCursor cursor,
            out ValueGetter<float> thresholdGetter,
            out ValueGetter<double> precisionGetter,
            out ValueGetter<double> recallGetter,
            out ValueGetter<double> fprGetter)
        {
            // Finds a p/r curve column by name and opens a typed getter for it on the cursor.
            ValueGetter<TValue> OpenGetter<TValue>(string columnName)
            {
                DataViewSchema.Column? column = prCurveView.Schema.GetColumnOrNull(columnName);
                Host.Assert(column != null);
                return cursor.GetGetter<TValue>(column.Value);
            }

            thresholdGetter = OpenGetter<float>(BinaryClassifierEvaluator.Threshold);
            precisionGetter = OpenGetter<double>(BinaryClassifierEvaluator.Precision);
            recallGetter = OpenGetter<double>(BinaryClassifierEvaluator.Recall);
            fprGetter = OpenGetter<double>(BinaryClassifierEvaluator.FalsePositiveRate);
        }
 
        /// <summary>
        /// Evaluates scored binary classification data, without probability-based metrics.
        /// </summary>
        /// <param name="data">The scored data.</param>
        /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
        /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
        /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
        /// <returns>The evaluation results for these uncalibrated outputs.</returns>
        /// <seealso cref="Evaluate(IDataView, string, string, string, string)"/>
        public BinaryClassificationMetrics Evaluate(IDataView data, string label, string score, string predictedLabel)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));
            Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

            // Bind every user-supplied column name to its evaluation role (no probability role here).
            var labelRole = RoleMappedSchema.ColumnRole.Label.Bind(label);
            var scoreRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score);
            var predictedLabelRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel);
            var roles = new RoleMappedData(data, opt: false, labelRole, scoreRole, predictedLabelRole);

            var metricViews = ((IEvaluator)this).Evaluate(roles);
            Host.Assert(metricViews.ContainsKey(MetricKinds.OverallMetrics));
            var overallView = metricViews[MetricKinds.OverallMetrics];
            var confusionView = metricViews[MetricKinds.ConfusionMatrix];

            // The overall metrics view is expected to hold exactly one row.
            BinaryClassificationMetrics metrics;
            using (var overallCursor = overallView.GetRowCursorForAllColumns())
            {
                bool hasRow = overallCursor.MoveNext();
                Host.Assert(hasRow);
                metrics = new BinaryClassificationMetrics(Host, overallCursor, confusionView);
                bool hasExtraRow = overallCursor.MoveNext();
                Host.Assert(!hasExtraRow);
            }
            return metrics;
        }
 
        /// <summary>
        /// Evaluates scored binary classification data, without probability-based metrics,
        /// and generates precision recall curve data.
        /// </summary>
        /// <param name="data">The scored data.</param>
        /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
        /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
        /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
        /// <param name="prCurve">The generated precision recall curve data. Up to 100000 samples are used for p/r curve generation.</param>
        /// <returns>The evaluation results for these uncalibrated outputs.</returns>
        /// <remarks>
        /// Throws a host exception with a descriptive message if the evaluator produced no precision recall
        /// curve data, for example because p/r curve sampling is disabled.
        /// </remarks>
        /// <seealso cref="Evaluate(IDataView, string, string, string)"/>
        public BinaryClassificationMetrics EvaluateWithPRCurve(
            IDataView data,
            string label,
            string score,
            string predictedLabel,
            out List<BinaryPrecisionRecallDataPoint> prCurve)
        {
            Host.CheckValue(data, nameof(data));
            Host.CheckNonEmpty(label, nameof(label));
            Host.CheckNonEmpty(score, nameof(score));
            Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

            var roles = new RoleMappedData(data, opt: false,
                RoleMappedSchema.ColumnRole.Label.Bind(label),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
                RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));

            var resultDict = ((IEvaluator)this).Evaluate(roles);

            // The p/r curve view is optional in the evaluator's output; it is missing when p/r curve sampling
            // is disabled. Check explicitly instead of relying on a debug assertion, so that callers get a
            // meaningful error rather than a bare KeyNotFoundException from the dictionary indexer.
            if (!resultDict.TryGetValue(MetricKinds.PrCurve, out var prCurveView))
                throw Host.Except("Precision-recall curve data was not produced by the evaluator. Make sure p/r curve generation is enabled (NumRocExamples > 0).");

            Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
            var overall = resultDict[MetricKinds.OverallMetrics];
            var confusionMatrix = resultDict[MetricKinds.ConfusionMatrix];

            // Materialize one data point per row of the p/r curve view.
            var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
            using (var cursor = prCurveView.GetRowCursorForAllColumns())
            {
                GetPrecisionRecallDataPointGetters(prCurveView, cursor,
                    out ValueGetter<float> thresholdGetter,
                    out ValueGetter<double> precisionGetter,
                    out ValueGetter<double> recallGetter,
                    out ValueGetter<double> fprGetter);

                while (cursor.MoveNext())
                {
                    prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
                }
            }
            prCurve = prCurveResult;

            // The overall metrics view is expected to hold exactly one row.
            BinaryClassificationMetrics result;
            using (var cursor = overall.GetRowCursorForAllColumns())
            {
                var moved = cursor.MoveNext();
                Host.Assert(moved);
                result = new BinaryClassificationMetrics(Host, cursor, confusionMatrix);
                moved = cursor.MoveNext();
                Host.Assert(!moved);
            }

            return result;
        }
    }
 
    /// <summary>
    /// Per-instance evaluator for binary classification. It appends an <see cref="Assigned"/> (predicted label)
    /// column and, when a probability column is available, a <see cref="LogLoss"/> column.
    /// </summary>
    internal sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase
    {
        public const string LoaderSignature = "BinaryPerInstance";
        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "BIN INST",
                verWrittenCur: 0x00010001, // Initial
                verReadableCur: 0x00010001,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(BinaryPerInstanceEvaluator).Assembly.FullName);
        }

        // Output column indices. Without a probability column only the Assigned column is produced,
        // and it sits at index 0 (== AssignedCol).
        private const int AssignedCol = 0;
        private const int LogLossCol = 1;

        public const string LogLoss = "Log-loss";
        public const string Assigned = "Assigned";

        // Name of the probability column; null when no probability column is used.
        private readonly string _probCol;
        // Index of the probability column in the input schema, or -1 when there is none.
        private readonly int _probIndex;
        // The predicted label is true iff the thresholded value is strictly greater than this (NaN gives false).
        private readonly Single _threshold;
        // True: threshold the raw score. False: threshold the probability.
        private readonly bool _useRaw;
        // Output column types, indexed by AssignedCol / LogLossCol.
        private readonly DataViewType[] _types;

        /// <summary>
        /// Creates the evaluator over <paramref name="schema"/>. A missing or unnamed probability column is
        /// tolerated (no log-loss column is produced) only when <paramref name="useRaw"/> is true.
        /// </summary>
        public BinaryPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, string probCol, string labelCol, Single threshold, bool useRaw)
            : base(env, schema, scoreCol, labelCol)
        {
            _threshold = threshold;
            _useRaw = useRaw;

            using (var ch = Host.Start("Finding Input Columns"))
            {
                _probIndex = -1;
                if (!string.IsNullOrEmpty(probCol) && schema.TryGetColumnIndex(probCol, out _probIndex))
                {
                    _probCol = probCol;
                }
                else
                {
                    // Forget the name of an unusable probability column. SaveModel persists _probCol, and the
                    // loading constructor requires that column to exist, so keeping the name would make the
                    // saved model impossible to load.
                    _probCol = null;
                    _probIndex = -1;
                    ch.Warning("Data does not contain a probability column. Will not output the Log-loss column");
                }
                CheckInputColumnTypes(schema);
            }

            _types = CreateOutputTypes();
        }

        private BinaryPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
            : base(env, ctx, schema)
        {
            // *** Binary format **
            // base
            // int: Id of the probability column name
            // float: _threshold
            // byte: _useRaw

            _probCol = ctx.LoadStringOrNull();
            _probIndex = -1;
            if (_probCol != null && !schema.TryGetColumnIndex(_probCol, out _probIndex))
                throw Host.ExceptParam(nameof(schema), "Did not find the probability column '{0}'", _probCol);

            _threshold = ctx.Reader.ReadFloat();
            _useRaw = ctx.Reader.ReadBoolByte();
            Host.CheckDecode(!string.IsNullOrEmpty(_probCol) || _useRaw);
            Host.CheckDecode(FloatUtils.IsFinite(_threshold));

            // Validate only after _useRaw has been read. CheckInputColumnTypes consults it, and with the
            // default (false) it would wrongly reject raw-score models that have no probability column.
            CheckInputColumnTypes(schema);

            _types = CreateOutputTypes();
        }

        public static BinaryPerInstanceEvaluator Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            return new BinaryPerInstanceEvaluator(env, ctx, schema);
        }

        private protected override void SaveModel(ModelSaveContext ctx)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());

            // *** Binary format **
            // base
            // int: Id of the probability column name
            // float: _threshold
            // byte: _useRaw

            base.SaveModel(ctx);
            ctx.SaveStringOrNull(_probCol);
            Contracts.Assert(FloatUtils.IsFinite(_threshold));
            ctx.Writer.Write(_threshold);
            Contracts.Assert(!string.IsNullOrEmpty(_probCol) || _useRaw);
            ctx.Writer.WriteBoolByte(_useRaw);
        }

        private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
        {
            // Log-loss needs the probability and the label. The predicted label needs either the raw
            // score or the probability, depending on _useRaw.
            if (_probIndex >= 0)
            {
                return
                    col =>
                        activeOutput(LogLossCol) && (col == _probIndex || col == LabelIndex) ||
                        activeOutput(AssignedCol) && (_useRaw && col == ScoreIndex || !_useRaw && col == _probIndex);
            }
            Host.Assert(_useRaw);
            return col => activeOutput(AssignedCol) && col == ScoreIndex;
        }

        private protected override Delegate[] CreateGettersCore(DataViewRow input, Func<int, bool> activeCols, out Action disposer)
        {
            Host.Assert(LabelIndex >= 0);
            Host.Assert(ScoreIndex >= 0);
            Host.Assert(_probIndex >= 0 || _useRaw);

            disposer = null;

            // Without a probability column there is no LogLossCol output, so it must not be queried.
            bool hasProb = _probIndex >= 0;
            bool logLossActive = hasProb && activeCols(LogLossCol);
            bool assignedActive = activeCols(AssignedCol);

            long cachedPosition = -1;
            Single label = 0;
            Single prob = 0;
            Single score = 0;

            ValueGetter<Single> nanGetter = (ref Single value) => value = Single.NaN;

            // The label is only used by the log-loss.
            var labelGetter = logLossActive ?
                RowCursorUtils.GetLabelGetter(input, LabelIndex) : nanGetter;

            // The probability feeds the log-loss and, when probabilities are thresholded, the predicted label.
            // This mirrors the dependencies declared in GetDependenciesCore.
            ValueGetter<Single> probGetter;
            if (hasProb && (logLossActive || (!_useRaw && assignedActive)))
                probGetter = input.GetGetter<Single>(input.Schema[_probIndex]);
            else
                probGetter = nanGetter;

            // The raw score is only read when the predicted label is computed from it.
            ValueGetter<Single> scoreGetter;
            if (_useRaw && assignedActive && ScoreIndex >= 0)
                scoreGetter = input.GetGetter<Single>(input.Schema[ScoreIndex]);
            else
                scoreGetter = nanGetter;

            // Read the inputs at most once per row, however many output getters are invoked.
            Action updateCacheIfNeeded;
            Func<bool> getPredictedLabel;
            if (_useRaw)
            {
                updateCacheIfNeeded =
                    () =>
                    {
                        if (cachedPosition != input.Position)
                        {
                            labelGetter(ref label);
                            probGetter(ref prob);
                            scoreGetter(ref score);
                            cachedPosition = input.Position;
                        }
                    };
                getPredictedLabel = () => GetPredictedLabel(score);
            }
            else
            {
                updateCacheIfNeeded =
                    () =>
                    {
                        if (cachedPosition != input.Position)
                        {
                            labelGetter(ref label);
                            probGetter(ref prob);
                            cachedPosition = input.Position;
                        }
                    };
                getPredictedLabel = () => GetPredictedLabel(prob);
            }

            var getters = hasProb ? new Delegate[2] : new Delegate[1];
            if (assignedActive)
            {
                ValueGetter<bool> predFn =
                    (ref bool dst) =>
                    {
                        updateCacheIfNeeded();
                        dst = getPredictedLabel();
                    };
                getters[hasProb ? AssignedCol : 0] = predFn;
            }
            if (logLossActive)
            {
                ValueGetter<Double> loglossFn =
                    (ref Double dst) =>
                    {
                        updateCacheIfNeeded();
                        dst = GetLogLoss(prob, label);
                    };
                getters[LogLossCol] = loglossFn;
            }
            return getters;
        }

        // Base-2 log-loss of the probability given to the true class; NaN when the probability or label is NaN.
        private Double GetLogLoss(Single prob, Single label)
        {
            if (Single.IsNaN(prob) || Single.IsNaN(label))
                return Double.NaN;
            if (label > 0)
                return -Math.Log(prob, 2);
            return -Math.Log(1.0 - prob, 2);
        }

        private bool GetPredictedLabel(Single val)
        {
            //Behavior for NA values is undefined.
            return Single.IsNaN(val) ? false : val > _threshold;
        }

        private protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
        {
            if (_probIndex >= 0)
            {
                var infos = new DataViewSchema.DetachedColumn[2];
                infos[LogLossCol] = new DataViewSchema.DetachedColumn(LogLoss, _types[LogLossCol], null);
                infos[AssignedCol] = new DataViewSchema.DetachedColumn(Assigned, _types[AssignedCol], null);
                return infos;
            }
            return new[] { new DataViewSchema.DetachedColumn(Assigned, _types[AssignedCol], null), };
        }

        // Validates the label, score and (optional) probability column types. It also requires a probability
        // column whenever the predicted label is computed from probabilities (_useRaw == false).
        private void CheckInputColumnTypes(DataViewSchema schema)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertValueOrNull(_probCol);
            Host.AssertNonEmpty(LabelCol);

            var t = schema[LabelIndex].Type;
            if (t != NumberDataViewType.Single && t != NumberDataViewType.Double && t != BooleanDataViewType.Instance && t.GetKeyCount() != 2)
                throw Host.ExceptSchemaMismatch(nameof(schema), "label", LabelCol, "Single, Double, Boolean or a Key with cardinality 2", t.ToString());

            t = schema[ScoreIndex].Type;
            if (t != NumberDataViewType.Single)
                throw Host.ExceptSchemaMismatch(nameof(schema), "score", ScoreCol, "Single", t.ToString());

            if (_probIndex >= 0)
            {
                Host.Assert(!string.IsNullOrEmpty(_probCol));
                t = schema[_probIndex].Type;
                if (t != NumberDataViewType.Single)
                    throw Host.ExceptSchemaMismatch(nameof(schema), "probability", _probCol, "Single", t.ToString());
            }
            else if (!_useRaw)
                throw Host.Except("Cannot compute the predicted label from the probability column because it does not exist");
        }

        // Builds the output column types, indexed by LogLossCol / AssignedCol.
        private static DataViewType[] CreateOutputTypes()
        {
            var types = new DataViewType[2];
            types[LogLossCol] = NumberDataViewType.Double;
            types[AssignedCol] = BooleanDataViewType.Instance;
            return types;
        }
    }
 
    [BestFriend]
    internal sealed class BinaryClassifierMamlEvaluator : MamlEvaluatorBase
    {
        /// <summary>
        /// Command-line arguments for the binary classifier MAML evaluator. Apart from the column names and
        /// the output file, these values are forwarded to the underlying <see cref="BinaryClassifierEvaluator"/>.
        /// </summary>
        public class Arguments : ArgumentsBase
        {
            // Optional probability column name. It is passed to EvaluateUtils.GetOptAuxScoreColumn, which
            // presumably falls back to a probability column associated with the score column when it is null.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Probability column name", ShortName = "prob")]
            public string ProbabilityColumn;

            // Classification threshold, forwarded to the evaluator's Threshold argument.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Probability value for classification thresholding")]
            public Single Threshold;

            [Argument(ArgumentType.AtMostOnce, HelpText = "Use raw score value instead of probability for classification thresholding", ShortName = "useRawScore")]
            public bool UseRawScoreThreshold = true;

            // Only takes effect when PRFilename is set; otherwise the constructor passes 0 to the evaluator.
            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for p/r curve generation. Specify 0 for no p/r curve generation", ShortName = "numpr")]
            public int NumRocExamples = 100000;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for AUC calculation. If 0, AUC is not computed. If -1, the whole dataset is used", ShortName = "numauc")]
            public int MaxAucExamples = -1;

            [Argument(ArgumentType.AtMostOnce, HelpText = "The number of samples to use for AUPRC calculation. Specify 0 for no AUPRC calculation", ShortName = "numauprc")]
            public int NumAuPrcExamples = 100000;

            // Optional output file for the precision/recall results; command-line only.
            [Argument(ArgumentType.AtMostOnce, HelpText = "Precision-Recall results filename", ShortName = "pr", Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
            public string PRFilename;
        }
 
        // Column names under which PrintFoldResultsCore copies the per-fold metrics before printing them.
        private const string FoldAccuracy = "OVERALL 0/1 ACCURACY";
        private const string FoldLogLoss = "LOG LOSS/instance";
        private const string FoldLogLosRed = "LOG-LOSS REDUCTION (RIG)";

        // The evaluator that computes the metrics; exposed to the base class through Evaluator.
        private readonly BinaryClassifierEvaluator _evaluator;

        // Optional precision/recall results file name (from Arguments.PRFilename).
        private readonly string _prFileName;
        // Optional user-specified probability column name (from Arguments.ProbabilityColumn).
        private readonly string _probCol;

        private protected override IEvaluator Evaluator => _evaluator;
 
        /// <summary>
        /// Creates the MAML evaluator and the underlying <see cref="BinaryClassifierEvaluator"/> from <paramref name="args"/>.
        /// </summary>
        public BinaryClassifierMamlEvaluator(IHostEnvironment env, Arguments args)
            : base(args, env, AnnotationUtils.Const.ScoreColumnKind.BinaryClassification, "BinaryClassifierMamlEvaluator")
        {
            Host.CheckValue(args, nameof(args));
            Utils.CheckOptionalUserDirectory(args.PRFilename, nameof(args.PRFilename));

            _prFileName = args.PRFilename;
            _probCol = args.ProbabilityColumn;

            // P/R curve samples are only collected when there is a file to write them to.
            bool writesPrCurve = !string.IsNullOrEmpty(args.PRFilename);
            var evalArgs = new BinaryClassifierEvaluator.Arguments
            {
                Threshold = args.Threshold,
                UseRawScoreThreshold = args.UseRawScoreThreshold,
                MaxAucExamples = args.MaxAucExamples,
                NumRocExamples = writesPrCurve ? args.NumRocExamples : 0,
                NumAuPrcExamples = args.NumAuPrcExamples,
            };
            _evaluator = new BinaryClassifierEvaluator(Host, evalArgs);
        }
 
        private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
        {
            var roles = base.GetInputColumnRolesCore(schema);

            var scoreColumn = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
                AnnotationUtils.Const.ScoreColumnKind.BinaryClassification);

            // The probability column is optional; a role is only added for it when one was found.
            var probColumn = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
                scoreColumn.Index, AnnotationUtils.Const.ScoreValueKind.Probability, NumberDataViewType.Single.Equals);
            if (!probColumn.HasValue)
                return roles;

            var probabilityRole = RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probColumn.Value.Name);
            return AnnotationUtils.Prepend(roles, probabilityRole);
        }
 
        /// <summary>
        /// Writes the fold's confusion table and a per-fold summary (accuracy, log-loss, entropy,
        /// log-loss reduction, AUC, plus any weighting/stratification columns present) to <paramref name="ch"/>.
        /// When weighted results exist, the weighted confusion table and summary are printed first.
        /// </summary>
        /// <param name="ch">Channel receiving the formatted output.</param>
        /// <param name="metrics">
        /// Metric views keyed by metric kind. Both <c>MetricKinds.OverallMetrics</c> and
        /// <c>MetricKinds.ConfusionMatrix</c> must be present; otherwise an exception from
        /// <c>ch.Except</c> is thrown naming the missing view.
        /// </param>
        private protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            ch.AssertValue(metrics);

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView fold))
                throw ch.Except("No overall metrics found");

            // Previously this reported "No overall metrics found", which misidentified the missing view.
            if (!metrics.TryGetValue(MetricKinds.ConfusionMatrix, out IDataView conf))
                throw ch.Except("No confusion matrix found");

            // Copy the evaluator's metric columns to the shorter per-fold display names.
            (string name, string source)[] cols =
            {
                (FoldAccuracy, BinaryClassifierEvaluator.Accuracy),
                (FoldLogLoss, BinaryClassifierEvaluator.LogLoss),
                (FoldLogLosRed, BinaryClassifierEvaluator.LogLossReduction)
            };

            var colsToKeep = new List<string>
            {
                FoldAccuracy,
                FoldLogLoss,
                BinaryClassifierEvaluator.Entropy,
                FoldLogLosRed,
                BinaryClassifierEvaluator.Auc
            };

            // Carry the optional weighting/stratification columns through when the fold view has them.
            var optionalCols = new[]
            {
                MetricKinds.ColumnNames.IsWeighted,
                MetricKinds.ColumnNames.StratCol,
                MetricKinds.ColumnNames.StratVal
            };
            foreach (var optionalCol in optionalCols)
            {
                if (fold.Schema.TryGetColumnIndex(optionalCol, out _))
                    colsToKeep.Add(optionalCol);
            }

            fold = new ColumnCopyingTransformer(Host, cols).Transform(fold);

            // Select the columns that are specified in the Copy
            fold = ColumnSelectingTransformer.CreateKeep(Host, fold, colsToKeep.ToArray());

            var unweightedConf = MetricWriter.GetConfusionTableAsFormattedString(Host, conf, out string weightedConf);
            var unweightedFold = MetricWriter.GetPerFoldResults(Host, fold, out string weightedFold);
            // Weighted confusion table and weighted fold results are expected to be present together.
            ch.Assert(string.IsNullOrEmpty(weightedConf) == string.IsNullOrEmpty(weightedFold));
            if (!string.IsNullOrEmpty(weightedConf))
            {
                ch.Info(MessageSensitivity.None, weightedConf);
                ch.Info(MessageSensitivity.None, weightedFold);
            }
            ch.Info(MessageSensitivity.None, unweightedConf);
            ch.Info(MessageSensitivity.None, unweightedFold);
        }
 
        /// <summary>
        /// Returns <paramref name="overall"/> with the entropy column dropped from the reported overall results.
        /// </summary>
        /// <param name="overall">Overall metrics view produced by the evaluator.</param>
        private protected override IDataView GetOverallResultsCore(IDataView overall)
            => ColumnSelectingTransformer.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
 
        /// <summary>
        /// When a p/r file name was configured, gathers the p/r metrics from all folds and saves the
        /// non-stratified portion to that file. Does nothing when no p/r file name was given.
        /// </summary>
        /// <param name="ch">Channel used for tracing and error reporting.</param>
        /// <param name="metrics">Per-fold metric dictionaries; must be non-empty.</param>
        private protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)
        {
            ch.AssertNonEmpty(metrics);

            if (string.IsNullOrEmpty(_prFileName))
                return;

            if (!TryGetPrMetrics(metrics, out IDataView prView))
                throw ch.Except("Did not find p/r metrics");

            ch.Trace(MessageSensitivity.None, "Saving p/r data view");
            // Keep only the overall (non-stratified) rows, dropping any stratification columns, before saving.
            var overallPrView = MetricWriter.GetNonStratifiedMetrics(Host, prView);
            MetricWriter.SavePerInstance(Host, ch, _prFileName, overallPrView);
        }
 
        /// <summary>
        /// Describes the overall metrics reported by this evaluator, in display order.
        /// Log-loss is the only one flagged with <c>MetricColumn.Objective.Minimize</c>.
        /// </summary>
        public override IEnumerable<MetricColumn> GetOverallMetricColumns()
        {
            return new[]
            {
                new MetricColumn("Accuracy", BinaryClassifierEvaluator.Accuracy),
                new MetricColumn("PosPrec", BinaryClassifierEvaluator.PosPrecName),
                new MetricColumn("PosRecall", BinaryClassifierEvaluator.PosRecallName),
                new MetricColumn("NegPrec", BinaryClassifierEvaluator.NegPrecName),
                new MetricColumn("NegRecall", BinaryClassifierEvaluator.NegRecallName),
                new MetricColumn("Auc", BinaryClassifierEvaluator.Auc),
                new MetricColumn("LogLoss", BinaryClassifierEvaluator.LogLoss, MetricColumn.Objective.Minimize),
                new MetricColumn("LogLossReduction", BinaryClassifierEvaluator.LogLossReduction),
                new MetricColumn("F1", BinaryClassifierEvaluator.F1),
                new MetricColumn("AuPrc", BinaryClassifierEvaluator.AuPrc)
            };
        }
 
        /// <summary>
        /// Collects the p/r curve view from every fold. With a single fold that view is returned as-is.
        /// With several folds, a fold-index column is added to each view, and the views are concatenated
        /// row-wise with <see cref="AppendRowsDataView"/>. No file is written and no averaging is performed here.
        /// </summary>
        /// <param name="metrics">Per-fold metric dictionaries; must be non-empty.</param>
        /// <param name="pr">The combined p/r view, or <c>null</c> if any fold lacks p/r metrics.</param>
        /// <returns><c>true</c> if every fold contained p/r metrics; otherwise <c>false</c>.</returns>
        private bool TryGetPrMetrics(Dictionary<string, IDataView>[] metrics, out IDataView pr)
        {
            Host.AssertNonEmpty(metrics);
            pr = null;

            int foldCount = metrics.Length;
            var prViews = new IDataView[foldCount];
            for (int i = 0; i < foldCount; i++)
            {
                if (!metrics[i].TryGetValue(BinaryClassifierEvaluator.PrCurve, out IDataView foldView))
                    return false;

                // Tag rows with their fold index only when several folds will be concatenated.
                prViews[i] = foldCount == 1
                    ? foldView
                    : EvaluateUtils.AddFoldIndex(Host, foldView, i, foldCount);
            }

            pr = foldCount == 1
                ? prViews[0]
                : AppendRowsDataView.Create(Host, prViews[0].Schema, prViews);
            return true;
        }
 
        /// <summary>
        /// Returns the names of the columns written to the per-instance output, in this order:
        /// <list type="number">
        /// <item>the label;</item>
        /// <item>the score;</item>
        /// <item>the probability column and the per-instance log-loss column, only when a probability column exists;</item>
        /// <item>a "FeatureContributions" column, if present;</item>
        /// <item>the assigned-label column.</item>
        /// </list>
        /// </summary>
        /// <param name="schema">Role-mapped schema of the scored data; must contain a label column.</param>
        private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)
        {
            // Validate eagerly: checks placed inside an iterator body are deferred until the first
            // MoveNext call, so a bad argument would otherwise surface far from this call site.
            Host.CheckValue(schema, nameof(schema));
            Host.CheckParam(schema.Label.HasValue, nameof(schema), "Schema must contain a label column");

            return EnumeratePerInstanceColumnsToSave(schema);
        }

        // Lazily yields the per-instance column names; expects `schema` to have been validated by the caller.
        private IEnumerable<string> EnumeratePerInstanceColumnsToSave(RoleMappedSchema schema)
        {
            // The binary classifier evaluator outputs the label, score and probability columns.
            yield return schema.Label.Value.Name;
            var scoreCol = EvaluateUtils.GetScoreColumn(Host, schema.Schema, ScoreCol, nameof(Arguments.ScoreColumn),
                AnnotationUtils.Const.ScoreColumnKind.BinaryClassification);
            yield return scoreCol.Name;
            var probCol = EvaluateUtils.GetOptAuxScoreColumn(Host, schema.Schema, _probCol, nameof(Arguments.ProbabilityColumn),
                scoreCol.Index, AnnotationUtils.Const.ScoreValueKind.Probability, NumberDataViewType.Single.Equals);
            // The LogLoss column is returned only if the probability column exists.
            if (probCol.HasValue)
            {
                yield return probCol.Value.Name;
                yield return BinaryPerInstanceEvaluator.LogLoss;
            }

            // REVIEW: Identify by metadata.
            if (schema.Schema.TryGetColumnIndex("FeatureContributions", out _))
                yield return "FeatureContributions";

            yield return BinaryPerInstanceEvaluator.Assigned;
        }
    }
 
    internal static partial class Evaluate
    {
        /// <summary>
        /// Entry point that evaluates a binary-classification scored dataset. It returns the warnings,
        /// overall metrics, per-instance metrics and confusion matrix. Empty views are substituted when
        /// the evaluator does not produce warnings, overall metrics or a confusion matrix.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Evaluator arguments including the scored data; must not be null.</param>
        [TlcModule.EntryPoint(Name = "Models.BinaryClassificationEvaluator", Desc = "Evaluates a binary classification scored dataset.")]
        public static CommonOutputs.ClassificationEvaluateOutput Binary(IHostEnvironment env, BinaryClassifierMamlEvaluator.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("EvaluateBinary");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            MatchColumns(host, input, out string labelCol, out string weightCol, out string nameCol);
            IMamlEvaluator binaryEvaluator = new BinaryClassifierMamlEvaluator(host, input);
            var scoredData = new RoleMappedData(input.Data, labelCol, null, null, weightCol, nameCol);
            var allMetrics = binaryEvaluator.Evaluate(scoredData);

            var warningsView = ExtractWarnings(host, allMetrics);
            var overallView = ExtractOverallMetrics(host, allMetrics, binaryEvaluator);
            var perInstanceView = binaryEvaluator.GetPerInstanceMetrics(scoredData);
            var confusionView = ExtractConfusionMatrix(host, allMetrics);

            return new CommonOutputs.ClassificationEvaluateOutput()
            {
                Warnings = warningsView,
                OverallMetrics = overallView,
                PerInstanceMetrics = perInstanceView,
                ConfusionMatrix = confusionView
            };
        }

        /// <summary>
        /// Resolves the label, weight and name column names against <c>input.Data</c>. Each one uses
        /// the user-specified value when given, otherwise the matching <c>DefaultColumnNames</c> entry;
        /// the result may be null, as delegated to <c>TrainUtils.MatchNameOrDefaultOrNull</c>.
        /// </summary>
        private static void MatchColumns(IHost host, MamlEvaluatorBase.ArgumentsBase input, out string label, out string weight, out string name)
        {
            var schema = input.Data.Schema;

            string Resolve(string argName, string userColumn, string defaultColumn)
                => TrainUtils.MatchNameOrDefaultOrNull(host, schema, argName, userColumn, defaultColumn);

            label = Resolve(nameof(BinaryClassifierMamlEvaluator.Arguments.LabelColumn), input.LabelColumn, DefaultColumnNames.Label);
            weight = Resolve(nameof(BinaryClassifierMamlEvaluator.Arguments.WeightColumn), input.WeightColumn, DefaultColumnNames.Weight);
            name = Resolve(nameof(BinaryClassifierMamlEvaluator.Arguments.NameColumn), input.NameColumn, DefaultColumnNames.Name);
        }

        /// <summary>
        /// Returns the warnings view, or an empty view with a single text warning column when none was produced.
        /// </summary>
        private static IDataView ExtractWarnings(IHost host, Dictionary<string, IDataView> metrics)
        {
            if (metrics.TryGetValue(MetricKinds.Warnings, out IDataView warnings))
                return warnings;

            var builder = new DataViewSchema.Builder();
            builder.AddColumn(MetricKinds.ColumnNames.WarningText, TextDataViewType.Instance);
            return new EmptyDataView(host, builder.ToSchema());
        }

        /// <summary>
        /// Returns the overall metrics view. When none was produced, returns an empty view with one
        /// <c>Double</c> column per metric listed by <paramref name="evaluator"/>.
        /// </summary>
        private static IDataView ExtractOverallMetrics(IHost host, Dictionary<string, IDataView> metrics, IMamlEvaluator evaluator)
        {
            if (metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView overallMetrics))
                return overallMetrics;

            var builder = new DataViewSchema.Builder();
            foreach (var metricColumn in evaluator.GetOverallMetricColumns())
                builder.AddColumn(metricColumn.LoadName, NumberDataViewType.Double);

            return new EmptyDataView(host, builder.ToSchema());
        }

        /// <summary>
        /// Returns the confusion matrix view, or an empty view with a single <c>Double</c> count column when none was produced.
        /// </summary>
        private static IDataView ExtractConfusionMatrix(IHost host, Dictionary<string, IDataView> metrics)
        {
            if (metrics.TryGetValue(MetricKinds.ConfusionMatrix, out IDataView confusionMatrix))
                return confusionMatrix;

            var builder = new DataViewSchema.Builder();
            builder.AddColumn(MetricKinds.ColumnNames.Count, NumberDataViewType.Double);
            return new EmptyDataView(host, builder.ToSchema());
        }
    }
}