From 3c7a478f66971a98f5b75115c6777dc7f4dcd44d Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Mon, 9 Jan 2023 10:12:45 +0100
Subject: [PATCH] instance segmentation metrics updated to be gt oriented

---
 metrics/instance_segmentation_metrics.py | 26 ++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index 34c715d..bd7e9cf 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -76,8 +76,18 @@ class InstanceSegmentationMetrics:
         # reshape to 1D
         ind_labels_Y = ind_labels_Y.reshape(-1) # labels from X matched to Y
 
+        # get all the indices in X which were matched to Y
+        residual_ind = np.delete(np.arange(X.shape[0]), ind.reshape(-1)) # indices of X which were not matched to Y
+
         # create a dictionary which contains Y, Y_labels and ind_labels_Y
-        dict_Y = {'Y': Y, 'Y_labels': Y_labels, 'ind_labels_Y': ind_labels_Y}
+        dict_Y = {
+            'X': X, # X is the input las file
+            'Y': Y, # Y is the instance segmented las file
+            'Y_labels': Y_labels,  # Y_labels is the instance segmented las file
+            'ind_labels_Y': ind_labels_Y, # ind_labels_Y is the labels from X matched to Y (new gt labels) 
+            'ind': ind, # ind is the indices of X which were matched to Y
+            'residual_ind': residual_ind # residual_ind is the indices of X which were not matched to Y
+            }
 
         return dict_Y
 
@@ -250,17 +260,17 @@ class InstanceSegmentationMetrics:
                 # get the indices of ind_labels_Y == dominant_label
                 ind_dominant_label = np.where(ind_labels_Y == dominant_label)[0]
 
-                # true positive is the number of points for dominant_label
+                ## true positive is the number of points for dominant_label
                 true_positive = ind_dominant_label.shape[0]
 
-                # false positive is the number of all the points of this dominant_label label minus the true positive
-                false_positive = np.where(self.dict_Y['ind_labels_Y'] == dominant_label)[0].shape[0] - true_positive
+                ## points which are within the relabelled pred but are not dominant_label
+                false_positive = ind_Y_labels_label.shape[0] - true_positive
 
-                # false negative is the number of all the points in Y_labels minus the number of points of true_positive
-                false_negative = np.where(ind_labels_Y != dominant_label)[0].shape[0] 
+                ## false negative is the number of points which are not in Y but are in X
+                false_negative = np.where(self.X_labels[self.dict_Y['residual_ind']] == dominant_label)[0].shape[0]
 
-                # true negative is the number of all the points minus the number of points of true_positive and false_positive
-                true_negative = self.dict_Y['ind_labels_Y'].shape[0] - false_negative - true_positive - false_positive
+                ## true negative 
+                true_negative = self.dict_Y['X'].shape[0] - false_negative - true_positive - false_positive
 
                 # sum all the true_positive, false_positive, false_negative, true_negative
                 sum_all = true_positive + false_positive + false_negative + true_negative
-- 
GitLab