File: AutoFormerV2\ObjectDetectionMetrics.cs
Web Access
Project: src\src\Microsoft.ML.TorchSharp\Microsoft.ML.TorchSharp.csproj (Microsoft.ML.TorchSharp)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
 
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Data;
 
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
    public class ObjectDetectionMetrics
    {
        /// <summary>
        /// Gets or sets mAP50: the mean Average Precision (mAP) computed at an IoU threshold of 50%.
        /// </summary>
        public float MAP50 { get; set; }
 
        /// <summary>
        /// Gets or sets mAP50-95: the average of the mAP values computed at IoU thresholds from 50% to 95%
        /// in steps of 5%, i.e. the mean of mAP50, mAP55, mAP60, ... and mAP95.
        /// </summary>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "<Pending>")]
        public float MAP50_95 { get; set; }
 
        /// <summary>
        /// A single labeled bounding box. Used in this file both for ground-truth annotations
        /// (where <see cref="Confidence"/> is set to 1) and for model predictions.
        /// </summary>
        private class ObjectLabel
        {
            /// <summary>
            /// Gets or sets the category (class name) of this box.
            /// </summary>
            public string Category { get; set; }
 
            /// <summary>
            /// Gets or sets the height of the bounding box.
            /// </summary>
            public float Height { get; set; }
 
            /// <summary>
            /// Gets or sets the width of the bounding box.
            /// </summary>
            public float Width { get; set; }
 
            /// <summary>
            /// Gets or sets the left border (x coordinate) of the bounding box.
            /// </summary>
            public float Left { get; set; }
 
            /// <summary>
            /// Gets or sets the top border (y coordinate) of the bounding box.
            /// </summary>
            public float Top { get; set; }
 
            /// <summary>
            /// Gets or sets the confidence score of the box; for predictions this is the model's score.
            /// </summary>
            public float Confidence { get; set; }
        }
 
        /// <summary>
        /// Drops every box whose confidence is below <paramref name="scoreThreshold"/>, image by image,
        /// and writes a one-line summary of how many boxes were removed to the console.
        /// </summary>
        /// <param name="objectLabelListIn">Per-image lists of boxes; a null outer list yields an empty result.</param>
        /// <param name="scoreThreshold">Minimum confidence (inclusive) a box needs in order to be kept.</param>
        /// <returns>A new outer list holding one filtered list per input image, in the same order.</returns>
        private static List<List<ObjectLabel>> ThresholdFilter(List<List<ObjectLabel>> objectLabelListIn, float scoreThreshold)
        {
            var keptPerImage = new List<List<ObjectLabel>>();
            int total = 0;
            int filterNum = 0;

            if (objectLabelListIn != null)
            {
                foreach (var imageBoxes in objectLabelListIn)
                {
                    var survivors = imageBoxes.Where(candidate => candidate.Confidence >= scoreThreshold).ToList();
                    keptPerImage.Add(survivors);

                    total += imageBoxes.Count;
                    filterNum += imageBoxes.Count - survivors.Count;
                }
            }

            Console.WriteLine($"total : {total}, filtered: {filterNum}, filter ratio: {(total == 0 ? 0f : ((double)filterNum / (double)total)):P}");
            return keptPerImage;
        }
 
        /// <summary>
        /// Computes mAP50 and mAP50-95 for the predictions stored in <paramref name="dataView"/>.
        /// </summary>
        /// <param name="dataView">Data containing both the ground-truth and the predicted columns.</param>
        /// <param name="labelCol">Column with the ground-truth category names of each row.</param>
        /// <param name="actualBoundingBoxColumn">Column with the ground-truth box coordinates of each row.</param>
        /// <param name="predictedLabelCol">Column with the predicted category names of each row.</param>
        /// <param name="predictedBoundingBoxColumn">Column with the predicted box coordinates of each row.</param>
        /// <param name="scoreCol">Column with the confidence score of each predicted box.</param>
        /// <returns>The metrics computed over all rows.</returns>
        internal static ObjectDetectionMetrics MeasureMetrics(IDataView dataView,
            DataViewSchema.Column labelCol,
            DataViewSchema.Column actualBoundingBoxColumn,
            DataViewSchema.Column predictedLabelCol,
            DataViewSchema.Column predictedBoundingBoxColumn,
            DataViewSchema.Column scoreCol
            )
        {
            // Gather the distinct ground-truth category names; they become the keys of both dictionaries below.
            var distinctCategories = new HashSet<string>();
            foreach (var row in dataView.GetColumn<VBuffer<ReadOnlyMemory<char>>>(labelCol))
            {
                foreach (var category in row.DenseValues())
                {
                    distinctCategories.Add(category.ToString());
                }
            }

            var groundTruth = ActualIdvToObjLabl(dataView, labelCol, actualBoundingBoxColumn);

            // Predictions scoring below 0.05 are discarded before matching.
            var predictions = ThresholdFilter(
                PredIdvToObjLabl(dataView, predictedLabelCol, predictedBoundingBoxColumn, scoreCol),
                0.05f);

            // Per category: the (confidence, IoU) pairs of its detections and its number of ground-truth boxes.
            var iouDic = new Dictionary<string, List<Tuple<float, float>>>();
            var groundTruthBoxes = new Dictionary<string, int>();
            foreach (string category in distinctCategories)
            {
                iouDic[category] = new List<Tuple<float, float>>();
                groundTruthBoxes[category] = 0;
            }

            GenerateIOUDic(groundTruth, predictions, iouDic, groundTruthBoxes);

            return new ObjectDetectionMetrics
            {
                MAP50 = AP50(iouDic, groundTruthBoxes, 0.5f),
                MAP50_95 = AP50_95(iouDic, groundTruthBoxes)
            };
        }
 
        /// <summary>
        /// Computes the mean of the per-category average precisions at a single IoU threshold.
        /// </summary>
        /// <param name="iouDic">Per category, the (confidence, IoU) pairs of its detections.</param>
        /// <param name="groundTruthBoxes">Per category, the number of ground-truth boxes.</param>
        /// <param name="threshold">IoU at or above which a detection counts as a true positive.</param>
        /// <returns>
        /// The summed per-category AP divided by the number of categories that have at least one
        /// ground-truth box, or 1 when no category has any ground truth.
        /// </returns>
        private static float AP50(Dictionary<string, List<Tuple<float, float>>> iouDic, Dictionary<string, int> groundTruthBoxes, float threshold = 0.5f)
        {
            int categoriesWithGroundTruth = groundTruthBoxes.Count(entry => entry.Value != 0);
            if (categoriesWithGroundTruth == 0)
            {
                // Nothing to detect, so the score is reported as perfect.
                return 1.0f;
            }

            float total = 0;
            foreach (var perCategory in iouDic)
            {
                total += LabelAP(perCategory.Value, groundTruthBoxes[perCategory.Key], threshold);
            }

            return total / categoriesWithGroundTruth;
        }
 
        /// <summary>
        /// Computes the mean Average Precision averaged over the IoU thresholds 0.50, 0.55, ..., 0.95.
        /// </summary>
        /// <param name="iouDic">Per category, the (confidence, IoU) pairs of its detections.</param>
        /// <param name="groundTruthBoxes">Per category, the number of ground-truth boxes.</param>
        /// <returns>
        /// The average over the ten thresholds of the mean per-category AP, or 1 when no category
        /// has any ground-truth box.
        /// </returns>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "<Pending>")]
        private static float AP50_95(Dictionary<string, List<Tuple<float, float>>> iouDic, Dictionary<string, int> groundTruthBoxes)
        {
            // Thresholds are 50%, 55%, ..., 95%: ten steps of five percentage points.
            const int firstThresholdPercent = 50;
            const int stepPercent = 5;
            const int thresholdCount = 10;

            int gt = groundTruthBoxes.Count(k => k.Value != 0);
            if (gt == 0)
            {
                return 1.0f;   // no ground truth
            }

            float ap5095 = 0f;

            // Step with an integer counter and derive each threshold by a single division.
            // Repeatedly adding 0.05f to a float accumulates rounding error, which pushes the later
            // thresholds slightly above their nominal value (e.g. just past 0.95) and can misclassify
            // detections whose IoU sits exactly on a threshold.
            for (int step = 0; step < thresholdCount; step++)
            {
                float thres = (firstThresholdPercent + (stepPercent * step)) / 100f;

                float apSum = 0;
                foreach (string label in iouDic.Keys)
                {
                    apSum += LabelAP(iouDic[label], groundTruthBoxes[label], thres);
                }

                ap5095 += apSum / gt;
            }

            return ap5095 / thresholdCount;
        }
 
        /// <summary>
        /// Computes the intersection-over-union of two axis-aligned boxes.
        /// </summary>
        /// <param name="bbox">The predicted box.</param>
        /// <param name="gt">The ground-truth box.</param>
        /// <returns>
        /// Intersection area divided by union area, or 0 when the boxes do not overlap
        /// (boxes that merely touch along an edge count as non-overlapping).
        /// </returns>
        private static float CalculateIoU(ObjectLabel bbox, ObjectLabel gt)
        {
            // Corners of the intersection rectangle.
            float interLeft = Math.Max(bbox.Left, gt.Left);
            float interTop = Math.Max(bbox.Top, gt.Top);
            float interRight = Math.Min(bbox.Left + bbox.Width, gt.Left + gt.Width);
            float interBottom = Math.Min(bbox.Top + bbox.Height, gt.Top + gt.Height);

            bool emptyHorizontally = interLeft >= interRight;
            bool emptyVertically = interTop >= interBottom;
            if (emptyHorizontally || emptyVertically)
            {
                return 0f;
            }

            float intersection = (interRight - interLeft) * (interBottom - interTop);
            float predictedArea = bbox.Width * bbox.Height;
            float groundTruthArea = gt.Width * gt.Height;

            return intersection / (predictedArea + groundTruthArea - intersection);
        }
 
        /// <summary>
        /// Matches the predicted boxes of one image against its ground-truth boxes and records the
        /// outcome per category.
        /// </summary>
        /// <remarks>
        /// Candidate pairs are same-category (ground truth, prediction) pairs with non-zero IoU. They are
        /// taken greedily in descending IoU order, and each box is used in at most one match. Every matched
        /// prediction is recorded as (confidence, IoU); every unmatched prediction is recorded as
        /// (confidence, 0). Categories missing from either dictionary are added on demand instead of
        /// raising <see cref="KeyNotFoundException"/>, so a predicted category that never occurs in the
        /// ground truth no longer aborts the evaluation.
        /// </remarks>
        /// <param name="objectLabel">Predicted boxes of the image; null is treated as no predictions.</param>
        /// <param name="annotation">Ground-truth boxes of the image; null is treated as no annotations.</param>
        /// <param name="iouDic">Per category, the list that receives the (confidence, IoU) pairs.</param>
        /// <param name="groundTruthBoxes">Per category, the ground-truth box counter that is incremented.</param>
        private static void ImageIoU(List<ObjectLabel> objectLabel, List<ObjectLabel> annotation, Dictionary<string, List<Tuple<float, float>>> iouDic, Dictionary<string, int> groundTruthBoxes)
        {
            objectLabel ??= new List<ObjectLabel>();
            annotation ??= new List<ObjectLabel>();

            // Returns the detection list of a category, registering the category in both dictionaries
            // first when needed so that later lookups of either dictionary by this key succeed.
            List<Tuple<float, float>> DetectionsFor(string category)
            {
                if (!iouDic.TryGetValue(category, out var detections))
                {
                    detections = new List<Tuple<float, float>>();
                    iouDic.Add(category, detections);
                }

                if (!groundTruthBoxes.ContainsKey(category))
                {
                    groundTruthBoxes.Add(category, 0);
                }

                return detections;
            }

            // IoU of every same-category (annotation, prediction) pair that overlaps.
            List<Tuple<int, int, float>> tupleList = new(); // annotaIndex, detectIndex, iouScore
            for (int annotaIndex = 0; annotaIndex < annotation.Count; annotaIndex++)
            {
                var gt = annotation[annotaIndex];
                DetectionsFor(gt.Category);
                groundTruthBoxes[gt.Category]++; // ground truth number
                for (int detectIndex = 0; detectIndex < objectLabel.Count; detectIndex++)
                {
                    var predBox = objectLabel[detectIndex];
                    if (predBox.Category != gt.Category)
                    {
                        continue;
                    }

                    float iou = CalculateIoU(predBox, gt);
                    if (iou != 0f)
                    {
                        tupleList.Add(Tuple.Create(annotaIndex, detectIndex, iou));
                    }
                }
            }

            tupleList.Sort((x, y) => y.Item3.CompareTo(x.Item3)); // descending by IoU

            bool[] annoSign = new bool[annotation.Count]; // whether an annotation box is already matched
            bool[] predSign = new bool[objectLabel.Count]; // whether a predicted box is already matched
            foreach (var tuple in tupleList)
            {
                if (!annoSign[tuple.Item1] && !predSign[tuple.Item2])
                {
                    DetectionsFor(annotation[tuple.Item1].Category).Add(
                        Tuple.Create(objectLabel[tuple.Item2].Confidence, tuple.Item3));
                    annoSign[tuple.Item1] = true;
                    predSign[tuple.Item2] = true;
                }
            }

            // Predictions left without a match are recorded with an IoU of 0.
            for (int predSignIndex = 0; predSignIndex < predSign.Length; predSignIndex++)
            {
                if (!predSign[predSignIndex])
                {
                    DetectionsFor(objectLabel[predSignIndex].Category).Add(
                        Tuple.Create(objectLabel[predSignIndex].Confidence, 0f));
                }
            }
        }
 
        /// <summary>
        /// Fills <paramref name="iouDic"/> and <paramref name="groundTruthBoxes"/> from all images, then
        /// sorts each category's detections by descending confidence.
        /// </summary>
        /// <param name="annotations">Per-image ground-truth boxes; null means there are no images.</param>
        /// <param name="objectLabelList">
        /// Per-image predicted boxes, index-aligned with <paramref name="annotations"/>. A null list, a
        /// list shorter than <paramref name="annotations"/>, or a null entry is treated as "no predictions
        /// for that image" rather than causing a null dereference or an out-of-range access.
        /// </param>
        /// <param name="iouDic">Per category, the list that receives the (confidence, IoU) pairs.</param>
        /// <param name="groundTruthBoxes">Per category, the ground-truth box counter.</param>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "<Pending>")]
        private static void GenerateIOUDic(List<List<ObjectLabel>> annotations, List<List<ObjectLabel>> objectLabelList, Dictionary<string, List<Tuple<float, float>>> iouDic, Dictionary<string, int> groundTruthBoxes)
        {
            int imageCount = annotations?.Count ?? 0;

            // each image
            for (int annoIndex = 0; annoIndex < imageCount; annoIndex++)
            {
                List<ObjectLabel> objectLabels = null;
                if (objectLabelList != null && annoIndex < objectLabelList.Count)
                {
                    objectLabels = objectLabelList[annoIndex];
                }

                ImageIoU(
                    objectLabels ?? new List<ObjectLabel>(),
                    annotations[annoIndex] ?? new List<ObjectLabel>(),
                    iouDic,
                    groundTruthBoxes);
            }

            // Order every category's detections by descending confidence.
            foreach (var detectionList in iouDic.Values)
            {
                detectionList.Sort((x, y) => y.Item1.CompareTo(x.Item1));   // compare score
            }
        }
 
        /// <summary>
        /// Computes the average precision of one category from its ranked detections.
        /// </summary>
        /// <param name="detectionList">
        /// (confidence, IoU) pairs of the category's detections, read in list order; the ranking is
        /// whatever order the caller supplies. Null or empty yields 0.
        /// </param>
        /// <param name="groundTruthBoxesNum">Number of ground-truth boxes of the category.</param>
        /// <param name="threshold">IoU at or above which a detection counts as a true positive.</param>
        /// <returns>
        /// The area under the precision/recall curve, where the precision at each recall level is replaced
        /// by the highest precision reached at that or any later rank. Returns 0 when there is no ground
        /// truth or no detection.
        /// </returns>
        private static float LabelAP(List<Tuple<float, float>> detectionList, int groundTruthBoxesNum, float threshold = 0.5f)
        {
            if (groundTruthBoxesNum == 0)
            {
                return 0f; // should be 0 or 1 here?
            }

            // "detectionList?.Count == 0" is false for null, which would fall through to a null
            // dereference below; treat a missing list the same as an empty one.
            if (detectionList is null || detectionList.Count == 0)
            {
                return 0f;
            }

            float tp = 0;   // running true-positive count, kept as float so the division below is not integral

            // One entry per rank: (precision at that rank, true positives so far). The second item is
            // recall scaled by groundTruthBoxesNum; the scaling is undone by the final division.
            Stack<Tuple<float, float>> prStack = new();
            for (int index = 0; index < detectionList.Count; index++)
            {
                if (detectionList[index].Item2 >= threshold)
                {
                    tp++;
                }

                prStack.Push(Tuple.Create(tp / (index + 1), tp));
            }

            // Walk from the last rank back to the first, tracking the best precision seen so far. Each
            // time an earlier rank has a strictly higher precision, close the current rectangle
            // (height = tracked precision, width = recall gained since that rank) and continue from there.
            float precisionRecord = prStack.Peek().Item1;
            float recallRecord = prStack.Pop().Item2;
            float ap = 0f;
            while (prStack.Count > 0)
            {
                Tuple<float, float> pr = prStack.Pop();
                float precision = pr.Item1;
                float recall = pr.Item2;
                if (precision > precisionRecord)
                {
                    ap += precisionRecord * (recallRecord - recall);
                    precisionRecord = precision;
                    recallRecord = recall;
                }
            }

            // Remaining rectangle, from recall 0 up to the last recorded recall.
            ap += precisionRecord * recallRecord;

            return ap / groundTruthBoxesNum;
        }
 
        /// <summary>
        /// Reads the predicted boxes of every row into per-row lists of <see cref="ObjectLabel"/>.
        /// </summary>
        /// <param name="idv">Data to read from.</param>
        /// <param name="predictedLabelCol">Column with the predicted category names of each row.</param>
        /// <param name="predictedBoundingBoxColumn">
        /// Column with the box coordinates; four consecutive values are consumed per box.
        /// </param>
        /// <param name="scoreCol">Column with one confidence score per predicted box.</param>
        /// <returns>One list of boxes per row, in row order.</returns>
        private static List<List<ObjectLabel>> PredIdvToObjLabl(
            IDataView idv,
            DataViewSchema.Column predictedLabelCol,
            DataViewSchema.Column predictedBoundingBoxColumn,
            DataViewSchema.Column scoreCol)
        {
            var data = new List<List<ObjectLabel>>();

            // The cursor is disposable; release it as soon as the rows have been read.
            using var cursor = idv.GetRowCursor(predictedLabelCol, predictedBoundingBoxColumn, scoreCol);

            var predLabGet = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(predictedLabelCol);
            var scoreGet = cursor.GetGetter<VBuffer<float>>(scoreCol);
            var boxGet = cursor.GetGetter<VBuffer<float>>(predictedBoundingBoxColumn);

            VBuffer<ReadOnlyMemory<char>> predLab = default;
            VBuffer<float> score = default;
            VBuffer<float> box = default;

            while (cursor.MoveNext())
            {
                predLabGet(ref predLab);
                scoreGet(ref score);
                boxGet(ref box);

                // Fetch the value spans once per row instead of once per box.
                var scores = score.GetValues();
                var categories = predLab.GetValues();
                var boxes = box.GetValues();

                var l = new List<ObjectLabel>(score.Length);
                var boxIdx = 0;

                for (int i = 0; i < score.Length; i++)
                {
                    var obj = new ObjectLabel();
                    obj.Confidence = scores[i];
                    obj.Category = categories[i].ToString();

                    // The third and fourth values are converted to a size relative to Left/Top, plus one.
                    // NOTE(review): this looks like inclusive right/bottom pixel coordinates - confirm
                    // against the producer of this column.
                    obj.Left = boxes[boxIdx++];
                    obj.Top = boxes[boxIdx++];
                    obj.Width = (boxes[boxIdx++] - obj.Left + 1);
                    obj.Height = (boxes[boxIdx++] - obj.Top + 1);

                    l.Add(obj);
                }
                data.Add(l);
            }

            return data;
        }
 
        /// <summary>
        /// Reads the ground-truth boxes of every row into per-row lists of <see cref="ObjectLabel"/>,
        /// each with a confidence of 1.
        /// </summary>
        /// <param name="idv">Data to read from.</param>
        /// <param name="labelCol">Column with the ground-truth category names of each row.</param>
        /// <param name="actualBoundingBoxColumn">
        /// Column with the box coordinates; four consecutive values are consumed per box.
        /// </param>
        /// <returns>One list of boxes per row, in row order.</returns>
        private static List<List<ObjectLabel>> ActualIdvToObjLabl(
            IDataView idv,
            DataViewSchema.Column labelCol,
            DataViewSchema.Column actualBoundingBoxColumn)
        {
            var data = new List<List<ObjectLabel>>();

            // The cursor is disposable; release it as soon as the rows have been read.
            using var cursor = idv.GetRowCursor(labelCol, actualBoundingBoxColumn);

            var labelGet = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(labelCol);
            var boxGet = cursor.GetGetter<VBuffer<float>>(actualBoundingBoxColumn);

            VBuffer<ReadOnlyMemory<char>> label = default;
            VBuffer<float> box = default;

            while (cursor.MoveNext())
            {
                labelGet(ref label);
                boxGet(ref box);

                // Fetch the value spans once per row instead of once per box.
                var categories = label.GetValues();
                var boxes = box.GetValues();

                var l = new List<ObjectLabel>(label.Length);
                var boxIdx = 0;

                for (int i = 0; i < label.Length; i++)
                {
                    var obj = new ObjectLabel();
                    obj.Confidence = 1f; // ground truth is recorded with full confidence
                    obj.Category = categories[i].ToString();

                    // The third and fourth values are converted to a size relative to Left/Top, plus one.
                    // NOTE(review): this looks like inclusive right/bottom pixel coordinates - confirm
                    // against the producer of this column.
                    obj.Left = boxes[boxIdx++];
                    obj.Top = boxes[boxIdx++];
                    obj.Width = (boxes[boxIdx++] - obj.Left + 1);
                    obj.Height = (boxes[boxIdx++] - obj.Top + 1);

                    l.Add(obj);
                }
                data.Add(l);
            }

            return data;
        }
    }
}