From 3d12e09f6c9a758e6d78bd1d849b5903634633e9 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Wed, 21 Dec 2022 10:40:00 +0100
Subject: [PATCH] update of the metrics

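Round the CSV metric values to three decimals, parallelize the semantic
segmentation metrics over files with joblib, replace the NearestNeighbors
lookup with a KDTree query, handle files where one of the four classes is
missing when averaging the per-class scores, and add an --adjust_gt_data
option that drops class 0 from the ground truth and shifts the remaining
labels to match the predictions. Also tidy a log message in
merging_and_labeling.py.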
---
 metrics/instance_segmentation_metrics.py      |   2 -
 ...instance_segmentation_metrics_in_folder.py |   2 +
 metrics/metrics_sem_seg.py                    | 157 +++++++++++++++---
 nibio_preprocessing/merging_and_labeling.py   |   2 +-
 4 files changed, 133 insertions(+), 30 deletions(-)

diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index a7b891b..83e97a7 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -1,5 +1,3 @@
-
-
 import argparse
 import os
 import laspy
diff --git a/metrics/instance_segmentation_metrics_in_folder.py b/metrics/instance_segmentation_metrics_in_folder.py
index f299e50..9bc000a 100644
--- a/metrics/instance_segmentation_metrics_in_folder.py
+++ b/metrics/instance_segmentation_metrics_in_folder.py
@@ -131,6 +131,8 @@ class InstanceSegmentationMetricsInFolder():
             with open(save_to_csv_path, 'w') as csv_file:
                 writer = csv.writer(csv_file)
                 for key, value in mean_metrics.items():
+                    # round the value to 3 decimal places
+                    value = round(value, 3)
                     writer.writerow([key, value])
                     
         if self.verbose:
diff --git a/metrics/metrics_sem_seg.py b/metrics/metrics_sem_seg.py
index 699bda9..19a737a 100644
--- a/metrics/metrics_sem_seg.py
+++ b/metrics/metrics_sem_seg.py
@@ -1,9 +1,11 @@
 import argparse
 import os
+from joblib import Parallel, delayed
 import laspy
 from matplotlib import pyplot as plt
 import numpy as np
-from sklearn.neighbors import NearestNeighbors
+# from sklearn.neighbors import NearestNeighbors
+from sklearn.neighbors import KDTree
 from tqdm import tqdm
 from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score, precision_score, recall_score
 
@@ -23,16 +25,17 @@ class MetricSemSeg:
         self, 
         gt_folder,
         pred_folder,
+        adjust_gt_data=False,
         plot_confusion_matrix=False, 
         verbose=False
         ):
         self.gt_folder = gt_folder
         self.pred_folder = pred_folder
+        self.adjust_gt_data = adjust_gt_data  # if True, the ground truth data is adjusted to match the predictions:
+        # points of class 0 (not relevant for the metrics) are dropped and the remaining labels are shifted down by one
         self.plot_confusion_matrix = plot_confusion_matrix
         self.verbose = verbose
 
-
-
     def get_file_name_list(self):
 
         # get list of the original point clouds
@@ -44,6 +47,11 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_original.sort()
 
+        # get the list of the core names of the original point clouds
+        # file_name_list_original_core = []
+        # for file_name in file_name_list_original:
+        #     file_name_list_original_core.append(os.path.basename(file_name).split('.')[0])
+
         # get list of the predicted point clouds
         file_name_list_predicted = []
         for file in os.listdir(self.pred_folder):
@@ -53,10 +61,6 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_predicted.sort()
 
-        if self.verbose:
-            print("file_name_list_original: ", file_name_list_original)
-            print("file_name_list_predicted: ", file_name_list_predicted)
-
         # check if the number of files in the two folders is the same
         if len(file_name_list_original) != len(file_name_list_predicted):
             raise Exception('The number of files in the two folders is not the same.')
@@ -84,20 +88,44 @@ class MetricSemSeg:
 
         return file_name_list
 
-    def get_labels_from_point_file(self, file_name):
+    def get_labels_from_point_file(self, file_name, predicted=False):
         """
         This function returns the labels of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        labels = point_cloud.label
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            labels = point_cloud.label
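+            # drop the points of class 0 and shift the remaining labels down by one so they line up with the predicted labels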
+            labels = labels[labels != 0]
+            labels = labels - 1
+        else:
+            labels = point_cloud.label
+
+        # convert to int
+        labels = labels.astype(int)
+
         return labels
 
-    def get_xyz_from_point_file(self, file_name):
+    def get_xyz_from_point_file(self, file_name, predicted=False):
         """
         This function returns the xyz coordinates of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            # keep only the points whose label is not 0
+            labels = point_cloud.label
+            xyz = np.vstack((
+                point_cloud.x[labels != 0],
+                point_cloud.y[labels != 0],
+                point_cloud.z[labels != 0]
+                )).transpose()
+
+        else:
+            xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
+
         return xyz
 
     def get_metrics_for_single_file(self, file_name_original, file_name_predicted):
@@ -106,8 +134,14 @@ class MetricSemSeg:
         """
 
         # get labels
-        labels_predicted = self.get_labels_from_point_file(file_name_predicted)
-        labels_original = self.get_labels_from_point_file(file_name_original) - 1
+        labels_predicted = self.get_labels_from_point_file(file_name_predicted, predicted=True)
+        labels_original = self.get_labels_from_point_file(file_name_original)
+
+        # find and print the ranges of the labels
+        if self.verbose:
+            print("labels_predicted range: ", np.min(labels_predicted), np.max(labels_predicted))
+            print("labels_original range: ", np.min(labels_original), np.max(labels_original))
+
         # print shape of labels
         if self.verbose:
             print("labels_predicted.shape: ", labels_predicted.shape)
@@ -115,19 +149,35 @@ class MetricSemSeg:
 
 
         # get points
+        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted, predicted=True)
         xyz_original = self.get_xyz_from_point_file(file_name_original)
-        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted)
+
 
         # find the closest point in the original point cloud for each point in the predicted point cloud using the euclidean distance using knn
 
-        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
-        distances, indices = nbrs.kneighbors(xyz_predicted)
+        # nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
+        # distances, indices = nbrs.kneighbors(xyz_predicted)
+
+        tree = KDTree(xyz_original, leaf_size=50, metric='euclidean')
+        # query the tree for the nearest original point to each predicted point
+        indices = tree.query(xyz_predicted, k=1, return_distance=False)
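+        # 'indices' has shape (n_predicted, 1); each row holds the index of the nearest ground truth point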
 
         # get the labels of the closest points
         labels_original_closest = labels_original[indices]
 
         # get the confusion matrix
-        conf_matrix = np.round(confusion_matrix(labels_original_closest, labels_predicted, normalize='true'), decimals=2)
+        conf_matrix = np.round(confusion_matrix(labels_original_closest, labels_predicted, normalize=None), decimals=2)
+
+        # if the confusion matrix is 3x3 (one class is missing from this file), insert a zero row and column at index 2 to make it 4x4
+        if conf_matrix.shape[0] == 3:
+            if self.verbose:
+                print("confusion matrix is 3x3, expanding to 4x4")
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=1)
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=0)
+
+        # print the confusion matrix shape
+        if self.verbose:
+            print("conf_matrix.shape: ", conf_matrix.shape)
 
         # get picture of the confusion matrix
         if self.plot_confusion_matrix:
@@ -184,12 +234,21 @@ class MetricSemSeg:
         f1_per_class_list = []
 
         # loop over all files
-        for file_name_original, file_name_predicted in tqdm(file_name_list):
-            if self.verbose:
-                print("file_name_original: ", file_name_original)
-                print("file_name_predicted: ", file_name_predicted)
-                
-            results = self.get_metrics_for_single_file(file_name_original, file_name_predicted)
+        # results = []
+        # for file_name_original, file_name_predicted in tqdm(file_name_list):
+        #     if self.verbose:
+        #         print("file_name_original: ", file_name_original)
+        #         print("file_name_predicted: ", file_name_predicted)
+
+        #     results.append(self.get_metrics_for_single_file(file_name_original, file_name_predicted))
+
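+        # run the per-file metric computation in parallel; n_jobs=-1 uses all available CPU cores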
+        results_list = Parallel(n_jobs=-1, verbose=0)(
+            delayed(self.get_metrics_for_single_file)(file_name_original, file_name_predicted) for file_name_original, file_name_predicted in file_name_list
+        )
+
+        # collect the per-file results
+        for results in results_list:
             conf_matrix_list.append(results['confusion_matrix'])
             precision_list.append(results['precision'])
             recall_list.append(results['recall'])
@@ -207,9 +266,46 @@ class MetricSemSeg:
         f1_mean = np.mean(f1_list)
 
         # compute the mean of the precision, recall and f1 score per class
-        precision_per_class_mean = np.mean(precision_per_class_list, axis=0)
-        recall_per_class_mean = np.mean(recall_per_class_list, axis=0)
-        f1_per_class_mean = np.mean(f1_per_class_list, axis=0)
+        # compute the per-class means separately for files with 3 and with 4 classes,
+        # since their per-class score vectors have different lengths
+        # create a separate list for items with 3 classes
+        precision_per_class_list_3 = []
+        recall_per_class_list_3 = []
+        f1_per_class_list_3 = []
+        # create a separate list for items with 4 classes
+        precision_per_class_list_4 = []
+        recall_per_class_list_4 = []
+        f1_per_class_list_4 = []
+        # loop over all items
+        for i in range(len(precision_per_class_list)):
+            # check if the item has 3 classes
+            if len(precision_per_class_list[i]) == 3:
+                precision_per_class_list_3.append(precision_per_class_list[i])
+                recall_per_class_list_3.append(recall_per_class_list[i])
+                f1_per_class_list_3.append(f1_per_class_list[i])
+            # check if the item has 4 classes
+            elif len(precision_per_class_list[i]) == 4:
+                precision_per_class_list_4.append(precision_per_class_list[i])
+                recall_per_class_list_4.append(recall_per_class_list[i])
+                f1_per_class_list_4.append(f1_per_class_list[i])
+            else:
+                raise Exception('The number of classes is neither 3 nor 4.')
+
+
+        # compute the mean of the precision, recall and f1 score per class for items with 3 classes
+        precision_per_class_mean_3 = np.mean(precision_per_class_list_3, axis=0)
+        recall_per_class_mean_3 = np.mean(recall_per_class_list_3, axis=0)
+        f1_per_class_mean_3 = np.mean(f1_per_class_list_3, axis=0)
+
+        # compute the mean of the precision, recall and f1 score per class for items with 4 classes
+        precision_per_class_mean_4 = np.mean(precision_per_class_list_4, axis=0)
+        recall_per_class_mean_4 = np.mean(recall_per_class_list_4, axis=0)
+        f1_per_class_mean_4 = np.mean(f1_per_class_list_4, axis=0)
+
+        # combine the two groups: concatenate the 3-class and 4-class mean vectors and average them into a single overall value per metric
+        precision_per_class_mean = np.mean(np.concatenate((precision_per_class_mean_3, precision_per_class_mean_4), axis=0), axis=0)
+        recall_per_class_mean = np.mean(np.concatenate((recall_per_class_mean_3, recall_per_class_mean_4), axis=0), axis=0)
+        f1_per_class_mean = np.mean(np.concatenate((f1_per_class_mean_3, f1_per_class_mean_4), axis=0), axis=0)
 
         # put all the results in a dictionary
         results = {
@@ -241,12 +337,19 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--path_original', type=str, default='data/original', help='path to the original point clouds directory')
     parser.add_argument('--path_predicted', type=str, default='data/predicted', help='path to the predicted point clouds directory')
+    parser.add_argument('--adjust_gt_data', action='store_true', default=False, help='drop class 0 from the ground truth and shift the remaining labels to match the predictions')
     parser.add_argument('--plot_confusion_matrix', action='store_true', default=False,  help='plot the confusion matrix')
     parser.add_argument('--verbose', action='store_true', default=False,  help='print more information')
     args = parser.parse_args()
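+    # example: python metrics/metrics_sem_seg.py --path_original data/original --path_predicted data/predicted --adjust_gt_data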
 
     # create an instance of the class
-    metrics = MetricSemSeg(args.path_original, args.path_predicted, args.plot_confusion_matrix, args.verbose)
+    metrics = MetricSemSeg(
+        args.path_original, 
+        args.path_predicted, 
+        args.adjust_gt_data,
+        args.plot_confusion_matrix, 
+        args.verbose
+        )
 
     # get the metrics
     results = metrics.main()
diff --git a/nibio_preprocessing/merging_and_labeling.py b/nibio_preprocessing/merging_and_labeling.py
index f452ac3..509752a 100644
--- a/nibio_preprocessing/merging_and_labeling.py
+++ b/nibio_preprocessing/merging_and_labeling.py
@@ -80,7 +80,7 @@ def merge_ply_files(data_folder, output_file='output_instance_segmented.ply'):
 
     # print where the file is saved
     logging.info("The file is saved in: " + os.path.join(data_folder, output_file))
-    logging.info("Merging was done for {} number of files".format(len(tags)))
+    logging.info("Merging was done for {} files".format(len(tags)))
 
     pipeline = pdal.Pipeline(json.dumps(data))
     pipeline.execute()
-- 
GitLab