From 79605dff3850bb58b2c55d31e96deb1987e16b14 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Tue, 3 Jan 2023 12:40:26 +0100 Subject: [PATCH] instance segmentation metrics visualizer --- ...instance_segmentation_metrics_in_folder.py | 2 +- visualization/inst_seg_visualizer.py | 122 ++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 visualization/inst_seg_visualizer.py diff --git a/metrics/instance_segmentation_metrics_in_folder.py b/metrics/instance_segmentation_metrics_in_folder.py index 9bc000a..8ddd37b 100644 --- a/metrics/instance_segmentation_metrics_in_folder.py +++ b/metrics/instance_segmentation_metrics_in_folder.py @@ -160,7 +160,7 @@ class InstanceSegmentationMetricsInFolder(): AttachLabelsToLasFile( gt_las_file_path, target_las_file_path, - update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '_updated.las'), + update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'), gt_label_name='treeID', target_label_name='treeID', verbose=self.verbose diff --git a/visualization/inst_seg_visualizer.py b/visualization/inst_seg_visualizer.py new file mode 100644 index 0000000..cdf6fa3 --- /dev/null +++ b/visualization/inst_seg_visualizer.py @@ -0,0 +1,122 @@ +import argparse +import glob +import os +import laspy +import numpy as np +import pandas as pd +from tqdm import tqdm + +class InstSegVisualizer: + GT_LABEL_NAME = 'treeID' + TARGET_LABEL_NAME = 'instance_nr' + CSV_GT_LABEL_NAME = 'gt_label(dominant_label)' + CSV_TARGET_LABEL_NAME = 'pred_label' + + def __init__(self, folder_with_metrics, verbose) -> None: + self.folder_with_metrics = folder_with_metrics + self.verbose = verbose + + def get_las_file_paths(self, folder_path): + las_file_paths = glob.glob(folder_path + '/*.las', recursive=False) + las_file_paths.sort() + return las_file_paths + + def get_csv_file_paths(self, folder_path): + csv_file_paths = glob.glob(folder_path + '/*.csv', recursive=False) + csv_file_paths.sort() + return csv_file_paths + + def match_las_and_csv_files(self, las_file_paths, csv_file_paths): + matched_paths = [] + for las_file_path, csv_file_path in zip(las_file_paths, csv_file_paths): + las_file_name = las_file_path.split('/')[-1] + csv_file_name = csv_file_path.split('/')[-1] + if las_file_name.split('.')[0] == csv_file_name.split('.')[0]: + matched_paths.append((las_file_path, csv_file_path)) + else: + raise Exception('The las file name and the csv file name do not match') + return matched_paths + + def get_gt_and_pred_labels_from_csv_for_single_file(self, csv_file_path): + df = pd.read_csv(csv_file_path) + gt_labels = df[self.CSV_GT_LABEL_NAME].values + pred_labels = df[self.CSV_TARGET_LABEL_NAME].values + # map to int + gt_labels = gt_labels.astype(int) + pred_labels = pred_labels.astype(int) + # create a list of tuples + matched_labels = list(zip(gt_labels, pred_labels)) + return matched_labels + + def extract_overlapping_pc_in_single_file(self, las_file_path, matched_labels): + # get the las file name + las_file_name = las_file_path.split('/')[-1] + + # create if does not exist the folder with metrics named after the las file + if not os.path.exists(self.folder_with_metrics + '/' + las_file_name.split('.')[0]): + os.makedirs(self.folder_with_metrics + '/' + las_file_name.split('.')[0]) + + for gt_label, pred_label in tqdm(matched_labels): + # read the las file + las_file = laspy.read(las_file_path) + + # get GT_LABEL_NAME points + gt_points = las_file.points[las_file.points[self.GT_LABEL_NAME] == gt_label] + # create a new array of gt points with the same shape as the red array and fill it with 255 + gt_points['red'] = np.ones(gt_points['red'].shape) * 255 + gt_points['green'] = np.ones(gt_points['green'].shape) * 0 + gt_points['blue'] = np.ones(gt_points['blue'].shape) * 0 + + # get TARGET_LABEL_NAME points + pred_points = las_file.points[las_file.points[self.TARGET_LABEL_NAME] == pred_label] + + # create a new array of pred points with the same shape as the blue array and fill it with 255 + pred_points['red'] = np.ones(pred_points['red'].shape) * 0 + pred_points['green'] = np.ones(pred_points['green'].shape) * 0 + pred_points['blue'] = np.ones(pred_points['blue'].shape) * 255 + + # dump both arrays to las to the folder with metrics named after the las file + las_file.points = gt_points + las_file.write(self.folder_with_metrics + '/' + las_file_name.split('.')[0] + '/' + las_file_name.split('.')[0] + '_gt_' + str(gt_label) + '.las') + las_file.points = pred_points + las_file.write(self.folder_with_metrics + '/' + las_file_name.split('.')[0] + '/' + las_file_name.split('.')[0] + '_pred_' + str(pred_label) + '.las') + + # merge using pdal and os.system and save the merged file in the folder with metrics named after the las file + os.system( + 'pdal merge ' + self.folder_with_metrics + '/' + + las_file_name.split('.')[0] + '/' + las_file_name.split('.')[0] + + '_gt_' + str(gt_label) + '.las ' + self.folder_with_metrics + '/' + + las_file_name.split('.')[0] + '/' + las_file_name.split('.')[0] + + '_pred_' + str(pred_label) + '.las ' + self.folder_with_metrics + '/' + + las_file_name.split('.')[0] + '/' + las_file_name.split('.')[0] + + '_gt_' + str(gt_label) + '_pred_' + str(pred_label) + '.las' + ) + + def extract_overlapping_pc_of_all_files(self, matched_files): + for las_file_path, csv_file_path in matched_files: + if self.verbose: + print('Processing file: ', las_file_path) + matched_labels = self.get_gt_and_pred_labels_from_csv_for_single_file(csv_file_path) + self.extract_overlapping_pc_in_single_file(las_file_path, matched_labels) + + def main(self): + las_file_paths = self.get_las_file_paths(self.folder_with_metrics) + csv_file_paths = self.get_csv_file_paths(self.folder_with_metrics) + matched_files = self.match_las_and_csv_files(las_file_paths, csv_file_paths) + self.extract_overlapping_pc_of_all_files(matched_files) + + if self.verbose: + # print the number of las files + print('Number of las files: ', len(las_file_paths)) + # print where the files were saved + print('The files were saved in: ', self.folder_with_metrics) + +if __name__ == '__main__': + # use argparse to get the folder with metrics + parser = argparse.ArgumentParser() + parser.add_argument('--folder_with_metrics', type=str, default='metrics') + parser.add_argument('--verbose', help="Print more information.", action="store_true") + args = parser.parse_args() + + InstSegVisualizer(args.folder_with_metrics, args.verbose).main() + -- GitLab