From d42f8a673ab7a647df4a482ad8b2f8137644bbd5 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Wed, 12 Oct 2022 12:29:23 +0200
Subject: [PATCH] overlap implemented

---
 metrics/instance_segmentation_metrics.py | 28 ++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index 9b5e1af..aa94bf6 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -24,10 +24,38 @@ class InstanceSegmentationMetrics():
         gt_classes = np.unique(las_gt.treeID)
         pred_classes = np.unique(las_pred.instance_nr)
 
+        # put x, y, z, for different classes in a dictionary
+        gt_dict = {}
+        pred_dict = {}
+        for gt_class in gt_classes:
+            gt_dict[gt_class] = np.vstack((las_gt.x[las_gt.treeID == gt_class], las_gt.y[las_gt.treeID == gt_class], las_gt.z[las_gt.treeID == gt_class])).T
+        for pred_class in pred_classes:
+            pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T
+
+        # compute overlap for each class
+        overlap = {}
+        for gt_class in gt_dict:
+            for pred_class in pred_dict:
+                overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class])
+
         # print the number of classes in gt and pred
         logging.info('Number of classes in gt: {}'.format(len(gt_classes)))
         logging.info('Number of classes in pred: {}'.format(len(pred_classes)))
 
+        # print first 10 classes in gt and pred
+        logging.info('First 10 classes in gt: {}'.format(gt_classes[:10]))
+        logging.info('First 10 classes in pred: {}'.format(pred_classes[:10]))
+
+        # print overlap for first 10 classes
+        logging.info('Overlap for first 10 classes: {}'.format(overlap))
+
+
+    def get_overlap(self, gt, pred):
+        # compute overlap between gt and pred
+        overlap = np.intersect1d(gt, pred).shape[0]
+        # overlap = np.sum(np.all(gt[:, None] == pred, axis=-1), axis=0)
+        return overlap
+
     def get_metrics_for_all_point_clouds(self):
         # get all las files in gt and pred folders using glob
         las_gt = glob.glob(os.path.join(self.gt_folder, '*.las'))
-- 
GitLab