Commit 3d12e09f authored by Maciej Wielgosz

update of the metrics

parent a71a5ef3
 import argparse
 import os
 import laspy
...
@@ -131,6 +131,8 @@ class InstanceSegmentationMetricsInFolder():
         with open(save_to_csv_path, 'w') as csv_file:
             writer = csv.writer(csv_file)
             for key, value in mean_metrics.items():
+                # round the value to 3 decimal places
+                value = round(value, 3)
                 writer.writerow([key, value])
         if self.verbose:
...
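For reference, a minimal standalone sketch of the new rounding behaviour when the mean metrics are written to CSV; the metric names and values below are made up:

    import csv

    # hypothetical stand-in for the mean_metrics dictionary
    mean_metrics = {'precision': 0.87654, 'recall': 0.91234, 'f1_score': 0.89399}

    with open('metrics.csv', 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for key, value in mean_metrics.items():
            # round the value to 3 decimal places before writing the row
            writer.writerow([key, round(value, 3)])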
 import argparse
 import os
+from joblib import Parallel, delayed
 import laspy
 from matplotlib import pyplot as plt
 import numpy as np
-from sklearn.neighbors import NearestNeighbors
+# from sklearn.neighbors import NearestNeighbors
+from sklearn.neighbors import KDTree
 from tqdm import tqdm
 from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score, precision_score, recall_score
...
@@ -23,16 +25,17 @@ class MetricSemSeg:
         self,
         gt_folder,
         pred_folder,
+        adjust_gt_data=False,
         plot_confusion_matrix=False,
         verbose=False
     ):
         self.gt_folder = gt_folder
         self.pred_folder = pred_folder
+        self.adjust_gt_data = adjust_gt_data  # if True, the ground truth data is adjusted to match the predicted data:
+        # points of class 0 are treated as unimportant, removed, and excluded from the metrics
         self.plot_confusion_matrix = plot_confusion_matrix
         self.verbose = verbose

     def get_file_name_list(self):
         # get list of the original point clouds
...
@@ -44,6 +47,11 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_original.sort()

+        # get the list of the core names of the original point clouds
+        # file_name_list_original_core = []
+        # for file_name in file_name_list_original:
+        #     file_name_list_original_core.append(os.path.basename(file_name).split('.')[0])

         # get list of the predicted point clouds
         file_name_list_predicted = []
         for file in os.listdir(self.pred_folder):
...
@@ -53,10 +61,6 @@ class MetricSemSeg:
         # sort the list in place
         file_name_list_predicted.sort()

-        if self.verbose:
-            print("file_name_list_original: ", file_name_list_original)
-            print("file_name_list_predicted: ", file_name_list_predicted)

         # check if the number of files in the two folders is the same
         if len(file_name_list_original) != len(file_name_list_predicted):
             raise Exception('The number of files in the two folders is not the same.')
...
@@ -84,20 +88,44 @@ class MetricSemSeg:
         return file_name_list

-    def get_labels_from_point_file(self, file_name):
+    def get_labels_from_point_file(self, file_name, predicted=False):
         """
         This function returns the labels of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        labels = point_cloud.label
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            labels = point_cloud.label
+            # drop class 0 and shift the remaining labels down by one
+            labels = labels[labels != 0]
+            labels = labels - 1
+        else:
+            labels = point_cloud.label
+        # convert to int
+        labels = labels.astype(int)
         return labels

-    def get_xyz_from_point_file(self, file_name):
+    def get_xyz_from_point_file(self, file_name, predicted=False):
         """
         This function returns the xyz coordinates of a point cloud file.
         """
         point_cloud = laspy.read(file_name)
-        xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
+        if self.adjust_gt_data and not predicted:
+            if self.verbose:
+                print("adjusting ground truth data")
+            # keep only the xyz of points whose label is not 0
+            labels = point_cloud.label
+            xyz = np.vstack((
+                point_cloud.x[labels != 0],
+                point_cloud.y[labels != 0],
+                point_cloud.z[labels != 0]
+            )).transpose()
+        else:
+            xyz = np.vstack((point_cloud.x, point_cloud.y, point_cloud.z)).transpose()
         return xyz
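A small self-contained sketch of what adjust_gt_data does, on made-up labels: class 0 is dropped from both the labels and the coordinates, and the surviving labels are shifted down so they start at 0 like the predictions:

    import numpy as np

    # toy ground-truth labels; 0 marks points that should be ignored
    labels = np.array([0, 1, 2, 0, 3, 1])
    xyz = np.random.rand(labels.size, 3)

    # keep only labelled points and shift 1..3 down to 0..2
    mask = labels != 0
    labels_adjusted = labels[mask] - 1
    xyz_adjusted = xyz[mask]

    print(labels_adjusted)     # [0 1 2 0]
    print(xyz_adjusted.shape)  # (4, 3)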
     def get_metrics_for_single_file(self, file_name_original, file_name_predicted):
...
@@ -106,8 +134,14 @@ class MetricSemSeg:
         """
         # get labels
-        labels_predicted = self.get_labels_from_point_file(file_name_predicted)
-        labels_original = self.get_labels_from_point_file(file_name_original) - 1
+        labels_predicted = self.get_labels_from_point_file(file_name_predicted, predicted=True)
+        labels_original = self.get_labels_from_point_file(file_name_original)

+        # find and print the ranges of the labels
+        if self.verbose:
+            print("labels_predicted range: ", np.min(labels_predicted), np.max(labels_predicted))
+            print("labels_original range: ", np.min(labels_original), np.max(labels_original))

         # print shape of labels
         if self.verbose:
             print("labels_predicted.shape: ", labels_predicted.shape)
...
@@ -115,19 +149,35 @@ class MetricSemSeg:
         # get points
+        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted, predicted=True)
         xyz_original = self.get_xyz_from_point_file(file_name_original)
-        xyz_predicted = self.get_xyz_from_point_file(file_name_predicted)

         # find the closest point in the original point cloud for each point in the predicted point cloud, using a euclidean-distance nearest-neighbour search
-        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
-        distances, indices = nbrs.kneighbors(xyz_predicted)
+        # nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(xyz_original)
+        # distances, indices = nbrs.kneighbors(xyz_predicted)
+        tree = KDTree(xyz_original, leaf_size=50, metric='euclidean')
+        # query the tree for the single nearest neighbour of each predicted point
+        indices = tree.query(xyz_predicted, k=1, return_distance=False)

         # get the labels of the closest points
         labels_original_closest = labels_original[indices]

         # get the raw-count confusion matrix (sklearn's normalize accepts 'true', 'pred', 'all' or None)
-        conf_matrix = np.round(confusion_matrix(labels_original_closest, labels_predicted, normalize='true'), decimals=2)
+        conf_matrix = confusion_matrix(labels_original_closest, labels_predicted, normalize=None)
+        # if the confusion matrix is 3x3, insert a zero row and column at index 2 to make it 4x4
+        if conf_matrix.shape[0] == 3:
+            if self.verbose:
+                print("conf_matrix.shape[0] == 3, expanding to 4x4")
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=1)
+            conf_matrix = np.insert(conf_matrix, 2, 0, axis=0)

+        # print the confusion matrix shape
+        if self.verbose:
+            print("conf_matrix.shape: ", conf_matrix.shape)

         # get picture of the confusion matrix
         if self.plot_confusion_matrix:
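The KDTree replacement can be exercised on synthetic data; note that query() with k=1 returns indices of shape (n, 1), so the transferred labels carry an extra axis unless the indices are flattened. A sketch, not part of the commit:

    import numpy as np
    from sklearn.neighbors import KDTree

    rng = np.random.default_rng(0)
    xyz_original = rng.random((100, 3))                       # reference cloud
    xyz_predicted = xyz_original + rng.normal(scale=0.01, size=(100, 3))
    labels_original = rng.integers(0, 4, size=100)

    # build the tree once, then look up the nearest original point for each prediction
    tree = KDTree(xyz_original, leaf_size=50, metric='euclidean')
    indices = tree.query(xyz_predicted, k=1, return_distance=False)  # shape (100, 1)

    labels_closest = labels_original[indices.ravel()]                # shape (100,)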
...
@@ -184,12 +234,21 @@ class MetricSemSeg:
         f1_per_class_list = []

         # loop over all files
-        for file_name_original, file_name_predicted in tqdm(file_name_list):
-            if self.verbose:
-                print("file_name_original: ", file_name_original)
-                print("file_name_predicted: ", file_name_predicted)
-            results = self.get_metrics_for_single_file(file_name_original, file_name_predicted)
+        # results = []
+        # for file_name_original, file_name_predicted in tqdm(file_name_list):
+        #     if self.verbose:
+        #         print("file_name_original: ", file_name_original)
+        #         print("file_name_predicted: ", file_name_predicted)
+        #     results.append(self.get_metrics_for_single_file(file_name_original, file_name_predicted))

+        # parallelize the computation over the file pairs
+        results = Parallel(n_jobs=-1, verbose=0)(
+            delayed(self.get_metrics_for_single_file)(file_name_original, file_name_predicted)
+            for file_name_original, file_name_predicted in file_name_list
+        )

-            conf_matrix_list.append(results['confusion_matrix'])
-            precision_list.append(results['precision'])
-            recall_list.append(results['recall'])
+        # extract the per-file results from the returned dictionaries
+        for result in results:
+            conf_matrix_list.append(result['confusion_matrix'])
+            precision_list.append(result['precision'])
+            recall_list.append(result['recall'])
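joblib runs each call in a separate worker process by default, so the callable and its arguments must be picklable and self-contained; a toy sketch of the pattern used above, with a made-up metric function:

    from joblib import Parallel, delayed

    def toy_metric(name_a, name_b):
        # hypothetical stand-in for get_metrics_for_single_file
        return {'pair': (name_a, name_b), 'score': len(name_a) + len(name_b)}

    file_pairs = [('gt_0.las', 'pred_0.las'), ('gt_1.las', 'pred_1.las')]

    # n_jobs=-1 uses all available cores; results come back in input order
    results = Parallel(n_jobs=-1, verbose=0)(
        delayed(toy_metric)(a, b) for a, b in file_pairs
    )
    print([r['score'] for r in results])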
...
@@ -207,9 +266,46 @@ class MetricSemSeg:
         f1_mean = np.mean(f1_list)

         # compute the mean of the precision, recall and f1 score per class
+        # compute separately for files which have 3 classes and files which have 4 classes
-        precision_per_class_mean = np.mean(precision_per_class_list, axis=0)
-        recall_per_class_mean = np.mean(recall_per_class_list, axis=0)
-        f1_per_class_mean = np.mean(f1_per_class_list, axis=0)
+        # lists for files with 3 classes
+        precision_per_class_list_3 = []
+        recall_per_class_list_3 = []
+        f1_per_class_list_3 = []
+        # lists for files with 4 classes
+        precision_per_class_list_4 = []
+        recall_per_class_list_4 = []
+        f1_per_class_list_4 = []
+        # loop over all files and group the per-class scores by class count
+        for i in range(len(precision_per_class_list)):
+            if len(precision_per_class_list[i]) == 3:
+                precision_per_class_list_3.append(precision_per_class_list[i])
+                recall_per_class_list_3.append(recall_per_class_list[i])
+                f1_per_class_list_3.append(f1_per_class_list[i])
+            elif len(precision_per_class_list[i]) == 4:
+                precision_per_class_list_4.append(precision_per_class_list[i])
+                recall_per_class_list_4.append(recall_per_class_list[i])
+                f1_per_class_list_4.append(f1_per_class_list[i])
+            else:
+                raise ValueError("the number of classes is neither 3 nor 4")

+        # compute the mean of the precision, recall and f1 score per class for files with 3 classes
+        # (note: np.mean of an empty group yields nan)
+        precision_per_class_mean_3 = np.mean(precision_per_class_list_3, axis=0)
+        recall_per_class_mean_3 = np.mean(recall_per_class_list_3, axis=0)
+        f1_per_class_mean_3 = np.mean(f1_per_class_list_3, axis=0)
+        # compute the mean of the precision, recall and f1 score per class for files with 4 classes
+        precision_per_class_mean_4 = np.mean(precision_per_class_list_4, axis=0)
+        recall_per_class_mean_4 = np.mean(recall_per_class_list_4, axis=0)
+        f1_per_class_mean_4 = np.mean(f1_per_class_list_4, axis=0)

+        # average over the concatenated 3-class and 4-class means
+        # (this yields a single scalar per metric, not a per-class vector)
+        precision_per_class_mean = np.mean(np.concatenate((precision_per_class_mean_3, precision_per_class_mean_4), axis=0), axis=0)
+        recall_per_class_mean = np.mean(np.concatenate((recall_per_class_mean_3, recall_per_class_mean_4), axis=0), axis=0)
+        f1_per_class_mean = np.mean(np.concatenate((f1_per_class_mean_3, f1_per_class_mean_4), axis=0), axis=0)
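The 3-class/4-class split exists because the per-file score vectors have different lengths, and NumPy cannot average a ragged list directly (depending on the version it raises an error or emits a deprecation warning). A sketch of the same grouping idea on toy data:

    import numpy as np

    # per-file per-class scores: some files expose 3 classes, others 4
    per_class = [np.array([0.9, 0.8, 0.7]),
                 np.array([0.9, 0.8, 0.7, 0.6]),
                 np.array([0.5, 0.6, 0.7])]

    # group by length, then average within each group
    by_len = {}
    for scores in per_class:
        by_len.setdefault(len(scores), []).append(scores)
    means = {n: np.stack(group).mean(axis=0) for n, group in by_len.items()}

    print(means[3])  # per-class mean over the 3-class files
    print(means[4])  # per-class mean over the 4-class file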
         # put all the results in a dictionary
         results = {
...
@@ -241,12 +337,19 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--path_original', type=str, default='data/original', help='path to the original point clouds directory')
     parser.add_argument('--path_predicted', type=str, default='data/predicted', help='path to the predicted point clouds directory')
+    parser.add_argument('--adjust_gt_data', action='store_true', default=False, help='adjust the ground truth data')
     parser.add_argument('--plot_confusion_matrix', action='store_true', default=False, help='plot the confusion matrix')
     parser.add_argument('--verbose', action='store_true', default=False, help='print more information')
     args = parser.parse_args()

     # create an instance of the class
-    metrics = MetricSemSeg(args.path_original, args.path_predicted, args.plot_confusion_matrix, args.verbose)
+    metrics = MetricSemSeg(
+        args.path_original,
+        args.path_predicted,
+        args.adjust_gt_data,
+        args.plot_confusion_matrix,
+        args.verbose
+    )
     # get the metrics
     results = metrics.main()
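Programmatic use mirrors the CLI; the module name and paths below are assumptions for illustration:

    # hypothetical module name; adjust to wherever MetricSemSeg lives
    from metric_sem_seg import MetricSemSeg

    metrics = MetricSemSeg(
        'data/original',             # gt_folder
        'data/predicted',            # pred_folder
        adjust_gt_data=True,         # drop class 0 from the ground truth
        plot_confusion_matrix=False,
        verbose=True,
    )
    results = metrics.main()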
...
@@ -80,7 +80,7 @@ def merge_ply_files(data_folder, output_file='output_instance_segmented.ply'):
     # print where the file is saved
     logging.info("The file is saved in: " + os.path.join(data_folder, output_file))
-    logging.info("Merging was done for {} number of files".format(len(tags)))
+    logging.info("Merging was done for {} files".format(len(tags)))

     pipeline = pdal.Pipeline(json.dumps(data))
     pipeline.execute()
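The `data` dict passed to pdal.Pipeline is built earlier in the file and is not part of this diff; a plausible minimal merge pipeline, for orientation only, could look like this:

    import json
    import pdal

    # hypothetical pipeline definition: read two tiles, merge, write one file
    data = {
        "pipeline": [
            "tile_0.ply",
            "tile_1.ply",
            {"type": "filters.merge"},
            {"type": "writers.ply", "filename": "output_instance_segmented.ply"},
        ]
    }

    pipeline = pdal.Pipeline(json.dumps(data))
    count = pipeline.execute()  # number of points processed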
...