From 3d12e09f6c9a758e6d78bd1d849b5903634633e9 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Wed, 21 Dec 2022 10:40:00 +0100
Subject: [PATCH] update of the metrics

---
 metrics/instance_segmentation_metrics.py | 2 -
 ...instance_segmentation_metrics_in_folder.py | 2 +
 metrics/metrics_sem_seg.py | 157 +++++++++++++++---
 nibio_preprocessing/merging_and_labeling.py | 2 +-
 4 files changed, 133 insertions(+), 30 deletions(-)

diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index a7b891b..83e97a7 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -1,5 +1,3 @@
-
-
 import argparse
 import os
 import laspy
diff --git a/metrics/instance_segmentation_metrics_in_folder.py b/metrics/instance_segmentation_metrics_in_folder.py
index f299e50..9bc000a 100644
--- a/metrics/instance_segmentation_metrics_in_folder.py
+++ b/metrics/instance_segmentation_metrics_in_folder.py
@@ -131,6 +131,8 @@ class InstanceSegmentationMetricsInFolder():
         with open(save_to_csv_path, 'w') as csv_file:
             writer = csv.writer(csv_file)
             for key, value in mean_metrics.items():
+                # round the value to 3 decimal places
+                value = round(value, 3)
                 writer.writerow([key, value])
 
         if self.verbose:
diff --git a/metrics/metrics_sem_seg.py b/metrics/metrics_sem_seg.py
index 699bda9..19a737a 100644
--- a/metrics/metrics_sem_seg.py
+++ b/metrics/metrics_sem_seg.py
@@ -1,9 +1,11 @@
 import argparse
 import os
+from joblib import Parallel, delayed
 import laspy
 from matplotlib import pyplot as plt
 import numpy as np
-from sklearn.neighbors import NearestNeighbors
+# from sklearn.neighbors import NearestNeighbors
+from sklearn.neighbors import KDTree
 from tqdm import tqdm
 from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score, precision_score, recall_score
@@ -23,16 +25,17 @@ class MetricSemSeg:
         self,
         gt_folder,
         pred_folder,
+        adjust_gt_data=False,
         plot_confusion_matrix=False,
         verbose=False
         ):
         self.gt_folder = gt_folder
         self.pred_folder = pred_folder
+        self.adjust_gt_data = adjust_gt_data  # if True, the ground truth data will be adjusted to match the predicted data
+        # this means removing points of class 0, which are considered unimportant and are excluded from the metrics
         self.plot_confusion_matrix = plot_confusion_matrix
         self.verbose = verbose
-
-
     def get_file_name_list(self):

         # get list of the original point clouds
@@ -44,6 +47,11 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_original.sort()
 
+        # get the list of the core names of the original point clouds
+        # file_name_list_original_core = []
+        # for file_name in file_name_list_original:
+        #     file_name_list_original_core.append(os.path.basename(file_name).split('.')[0])
+
         # get list of the predicted point clouds
         file_name_list_predicted = []
         for file in os.listdir(self.pred_folder):
@@ -53,10 +61,6 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_predicted.sort()
 
-        if self.verbose:
-            print("file_name_list_original: ", file_name_list_original)
-            print("file_name_list_predicted: ", file_name_list_predicted)
-
         # check if the number of files in the two folders is the same
         if len(file_name_list_original) != len(file_name_list_predicted):
             raise Exception('The number of files in the two folders is not the same.')
@@ -84,20 +88,44 @@ class MetricSemSeg:
 
         return file_name_list
 
-    def get_labels_from_point_file(self, file_name):
+    def get_labels_from_point_file(self, file_name, predicted=False):
         """
         This function returns the labels of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        labels = point_cloud.label
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            labels = point_cloud.label
+            labels = labels[labels != 0]
+            labels = labels - 1
+        else:
+            labels = point_cloud.label
+
+        # convert to int
+        labels = labels.astype(int)
+
         return labels
 
-    def get_xyz_from_point_file(self, file_name):
+    def get_xyz_from_point_file(self, file_name, predicted=False):
         """
         This function returns the xyz coordinates of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            # keep only the points whose label is not 0
+            labels = point_cloud.label
+            xyz = np.vstack((
+                point_cloud.x[labels != 0],
+                point_cloud.y[labels != 0],
+                point_cloud.z[labels != 0]
+            )).transpose()
+
+        else:
+            xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
+
         return xyz
 
     def get_metrics_for_single_file(self, file_name_original, file_name_predicted):
@@ -106,8 +134,14 @@
         """
 
         # get labels
-        labels_predicted = self.get_labels_from_point_file(file_name_predicted)
-        labels_original = self.get_labels_from_point_file(file_name_original) - 1
+        labels_predicted = self.get_labels_from_point_file(file_name_predicted, predicted=True)
+        labels_original = self.get_labels_from_point_file(file_name_original)
+
+        # find and print the ranges of the labels
+        if self.verbose:
+            print("labels_predicted range: ", np.min(labels_predicted), np.max(labels_predicted))
+            print("labels_original range: ", np.min(labels_original), np.max(labels_original))
+
         # print shape of labels
         if self.verbose:
             print("labels_predicted.shape: ", labels_predicted.shape)
@@ -115,19 +149,35 @@
 
         # get points
+        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted, predicted=True)
         xyz_original = self.get_xyz_from_point_file(file_name_original)
-        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted)
+
         # find the closest point in the original point cloud for each point in the predicted point cloud (euclidean nearest-neighbour search)
-        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
-        distances, indices = nbrs.kneighbors(xyz_predicted)
+        # nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
+        # distances, indices = nbrs.kneighbors(xyz_predicted)
+
+        tree = KDTree(xyz_original, leaf_size=50, metric='euclidean')
+        # query the tree for the nearest original point of each predicted point
+        indices = tree.query(xyz_predicted, k=1, return_distance=False).squeeze()
 
         # get the labels of the closest points
         labels_original_closest = labels_original[indices]
 
         # get the confusion matrix
-        conf_matrix = np.round(confusion_matrix(labels_original_closest, labels_predicted, normalize='true'), decimals=2)
+        conf_matrix = confusion_matrix(labels_original_closest, labels_predicted)  # raw counts, not normalized
+
+        # if the confusion matrix is 3x3 (one class is missing from the file), insert a zero row and column at index 2 to make it 4x4
+        if conf_matrix.shape[0] == 3:
+            if self.verbose:
+                print("conf_matrix is 3x3, expanding to 4x4")
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=1)
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=0)
+
+        # print the confusion matrix shape
+        if self.verbose:
+            print("conf_matrix.shape: ", conf_matrix.shape)
 
         # get picture of the confusion matrix
         if self.plot_confusion_matrix:
@@ -184,12 +234,21 @@ class MetricSemSeg:
         f1_per_class_list = []
 
         # loop over all files
-        for file_name_original, file_name_predicted in tqdm(file_name_list):
-            if self.verbose:
-                print("file_name_original: ", file_name_original)
-                print("file_name_predicted: ", file_name_predicted)
-
-            results = self.get_metrics_for_single_file(file_name_original, file_name_predicted)
+        # results = []
+        # for file_name_original, file_name_predicted in tqdm(file_name_list):
+        #     if self.verbose:
+        #         print("file_name_original: ", file_name_original)
+        #         print("file_name_predicted: ", file_name_predicted)
+
+        #     results.append(self.get_metrics_for_single_file(file_name_original, file_name_predicted))
+
+        # parallelize the computation
+        results = Parallel(n_jobs=-1, verbose=0)(
+            delayed(self.get_metrics_for_single_file)(file_name_original, file_name_predicted) for file_name_original, file_name_predicted in file_name_list
+        )
+
+        # extract the per-file results from the list of result dictionaries
+        for results in results:
             conf_matrix_list.append(results['confusion_matrix'])
             precision_list.append(results['precision'])
             recall_list.append(results['recall'])
@@ -207,9 +266,46 @@ class MetricSemSeg:
         f1_mean = np.mean(f1_list)
 
         # compute the mean of the precision, recall and f1 score per class
-        precision_per_class_mean = np.mean(precision_per_class_list, axis=0)
-        recall_per_class_mean = np.mean(recall_per_class_list, axis=0)
-        f1_per_class_mean = np.mean(f1_per_class_list, axis=0)
+        # compute separately for items which have 3 and for items which have 4 classes
+        # create a separate list for items with 3 classes
+        precision_per_class_list_3 = []
+        recall_per_class_list_3 = []
+        f1_per_class_list_3 = []
+        # create a separate list for items with 4 classes
+        precision_per_class_list_4 = []
+        recall_per_class_list_4 = []
+        f1_per_class_list_4 = []
+        # loop over all items
+        for i in range(len(precision_per_class_list)):
+            # check if the item has 3 classes
+            if len(precision_per_class_list[i]) == 3:
+                precision_per_class_list_3.append(precision_per_class_list[i])
+                recall_per_class_list_3.append(recall_per_class_list[i])
+                f1_per_class_list_3.append(f1_per_class_list[i])
+            # check if the item has 4 classes
+            elif len(precision_per_class_list[i]) == 4:
+                precision_per_class_list_4.append(precision_per_class_list[i])
+                recall_per_class_list_4.append(recall_per_class_list[i])
+                f1_per_class_list_4.append(f1_per_class_list[i])
+            else:
+                print("ERROR: the number of classes is neither 3 nor 4")
+                exit()
+
+
+        # compute the mean of the precision, recall and f1 score per class for items with 3 classes
+        precision_per_class_mean_3 = np.mean(precision_per_class_list_3, axis=0)
+        recall_per_class_mean_3 = np.mean(recall_per_class_list_3, axis=0)
+        f1_per_class_mean_3 = np.mean(f1_per_class_list_3, axis=0)
+
+        # compute the mean of the precision, recall and f1 score per class for items with 4 classes
+        precision_per_class_mean_4 = np.mean(precision_per_class_list_4, axis=0)
+        recall_per_class_mean_4 = np.mean(recall_per_class_list_4, axis=0)
+        f1_per_class_mean_4 = np.mean(f1_per_class_list_4, axis=0)
+
+        # collapse the per-class means of the two groups into a single overall mean
+        precision_per_class_mean = np.mean(np.concatenate((precision_per_class_mean_3, precision_per_class_mean_4), axis=0), axis=0)
+        recall_per_class_mean = np.mean(np.concatenate((recall_per_class_mean_3, recall_per_class_mean_4), axis=0), axis=0)
+        f1_per_class_mean = np.mean(np.concatenate((f1_per_class_mean_3, f1_per_class_mean_4), axis=0), axis=0)
 
         # put all the results in a dictionary
         results = {
@@ -241,12 +337,19 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--path_original', type=str, default='data/original', help='path to the original point clouds directory')
     parser.add_argument('--path_predicted', type=str, default='data/predicted', help='path to the predicted point clouds directory')
+    parser.add_argument('--adjust_gt_data', action='store_true', default=False, help='remove class-0 points from the ground truth and shift the remaining labels down by one')
     parser.add_argument('--plot_confusion_matrix', action='store_true', default=False, help='plot the confusion matrix')
     parser.add_argument('--verbose', action='store_true', default=False, help='print more information')
     args = parser.parse_args()
 
     # create an instance of the class
-    metrics = MetricSemSeg(args.path_original, args.path_predicted, args.plot_confusion_matrix, args.verbose)
+    metrics = MetricSemSeg(
+        args.path_original,
+        args.path_predicted,
+        args.adjust_gt_data,
+        args.plot_confusion_matrix,
+        args.verbose
+    )
     # get the metrics
     results = metrics.main()
diff --git a/nibio_preprocessing/merging_and_labeling.py b/nibio_preprocessing/merging_and_labeling.py
index f452ac3..509752a 100644
--- a/nibio_preprocessing/merging_and_labeling.py
+++ b/nibio_preprocessing/merging_and_labeling.py
@@ -80,7 +80,7 @@ def merge_ply_files(data_folder, output_file='output_instance_segmented.ply'):
     # print where the file is saved
     logging.info("The file is saved in: " + os.path.join(data_folder, output_file))
 
-    logging.info("Merging was done for {} number of files".format(len(tags)))
+    logging.info("Merging was done for {} files".format(len(tags)))
 
     pipeline = pdal.Pipeline(json.dumps(data))
     pipeline.execute()
-- 
GitLab
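
A minimal, self-contained sketch of the matching step this patch switches to: a KDTree lookup of the nearest ground-truth label for every predicted point, followed by zero-padding of a 3x3 confusion matrix to 4x4. The synthetic arrays and the four-class assumption are illustrative stand-ins, not data from the repository; the real inputs are read with laspy in metrics_sem_seg.py.

    import numpy as np
    from sklearn.metrics import confusion_matrix
    from sklearn.neighbors import KDTree

    rng = np.random.default_rng(0)
    xyz_original = rng.random((1000, 3))         # ground-truth coordinates (synthetic)
    labels_original = rng.integers(0, 4, 1000)   # ground-truth labels, classes 0..3 (synthetic)
    xyz_predicted = xyz_original + rng.normal(scale=1e-3, size=(1000, 3))
    labels_predicted = rng.integers(0, 4, 1000)  # predicted labels (synthetic)

    # for each predicted point, find the index of the closest ground-truth point
    tree = KDTree(xyz_original, leaf_size=50, metric='euclidean')
    indices = tree.query(xyz_predicted, k=1, return_distance=False).squeeze()
    labels_original_closest = labels_original[indices]

    # raw (un-normalized) confusion matrix, as in the patched code
    conf_matrix = confusion_matrix(labels_original_closest, labels_predicted)

    # pad a 3x3 matrix (one class absent from a file) with a zero row and
    # column at index 2 so per-file matrices stay comparable
    if conf_matrix.shape[0] == 3:
        conf_matrix = np.insert(conf_matrix, 2, 0, axis=1)
        conf_matrix = np.insert(conf_matrix, 2, 0, axis=0)

    print(conf_matrix)

Note that the padding always inserts at index 2, which assumes the class absent from a file is the third one; if a different class can be missing, the insert position would have to be derived from the labels actually present.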
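The replaced serial loop follows the standard joblib fan-out pattern sketched below; score_pair is a hypothetical stand-in for MetricSemSeg.get_metrics_for_single_file and the file pairs are made up. Parallel returns results in input order, so the aggregation loop that consumes them is unaffected by the parallelization.

    from joblib import Parallel, delayed

    def score_pair(gt_path, pred_path):
        # placeholder for the real per-file metric computation
        return {'gt': gt_path, 'pred': pred_path}

    file_name_list = [('gt_0.las', 'pred_0.las'), ('gt_1.las', 'pred_1.las')]

    # n_jobs=-1 uses all cores; results[i] corresponds to file_name_list[i]
    results = Parallel(n_jobs=-1, verbose=0)(
        delayed(score_pair)(gt, pred) for gt, pred in file_name_list
    )
    print(results)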