From 9925c0b52d00f242dea53f774438b0731d89631f Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Sun, 8 Jan 2023 16:43:54 +0100
Subject: [PATCH] updated metrics for instance segmentation - Austrian dataset

---
 .../instance_segmentation_metrics_austrian.py | 449 ++++++++++++++++++
 1 file changed, 449 insertions(+)
 create mode 100644 metrics/instance_segmentation_metrics_austrian.py

diff --git a/metrics/instance_segmentation_metrics_austrian.py b/metrics/instance_segmentation_metrics_austrian.py
new file mode 100644
index 0000000..a338ca0
--- /dev/null
+++ b/metrics/instance_segmentation_metrics_austrian.py
@@ -0,0 +1,449 @@
+import argparse
+import os
+import laspy
+import logging
+import numpy as np
+import pandas as pd
+from sklearn.neighbors import KDTree
+from tqdm import tqdm
+
+logging.basicConfig(level=logging.INFO)
+
+
+class InstanceSegmentationMetrics:
+    GT_LABEL_NAME = 'StemID'
+    TARGET_LABEL_NAME = 'instance_nr'
+
+    def __init__(
+        self,
+        input_file_path,
+        instance_segmented_file_path,
+        remove_ground=False,
+        csv_file_name=None,
+        verbose=False
+    ) -> None:
+
+        self.input_file_path = input_file_path
+        self.instance_segmented_file_path = instance_segmented_file_path
+        self.remove_ground = remove_ground
+        self.csv_file_name = csv_file_name
+        self.verbose = verbose
+        # read the input (ground truth) las file and the instance segmented las file
+        self.input_las = laspy.read(self.input_file_path)
+        self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
+
+        self.skip_flag = self.check_if_labels_exist(
+            X_label=self.GT_LABEL_NAME,
+            Y_label=self.TARGET_LABEL_NAME
+        )
+
+        if not self.skip_flag:
+            # get labels from the input las file
+            self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int)
+            # get labels from the instance segmented las file
+            self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int)
+            # if self.remove_ground:
+            #     # the labeling starts from 0, so we need to remove the ground
+            #     self.Y_labels += 1
+
+            # do the knn mapping
+            self.dict_Y = self.do_knn_mapping()
+        else:
+            logging.info('Skipping the file: {}'.format(self.input_file_path))
+
+    def check_if_labels_exist(self, X_label='treeID', Y_label='instance_nr'):
+        # check if the labels exist in the las files
+        skip_flag = False
+
+        if X_label not in self.input_las.header.point_format.dimension_names:
+            skip_flag = True
+        if Y_label not in self.instance_segmented_las.header.point_format.dimension_names:
+            skip_flag = True
+
+        return skip_flag
+
+    def do_knn_mapping(self):
+        X = self.input_las.xyz
+        Y = self.instance_segmented_las.xyz
+        X_labels = self.X_labels
+        Y_labels = self.Y_labels
+
+        # create a KDTree for X
+        tree = KDTree(X, leaf_size=50, metric='euclidean')
+        # query the tree with Y (nearest neighbour in X for every point of Y)
+        ind = tree.query(Y, k=1, return_distance=False)
+
+        # get the X labels for the matched indices
+        ind_labels_Y = X_labels[ind]
+
+        # reshape to 1D
+        ind_labels_Y = ind_labels_Y.reshape(-1)  # labels from X matched to Y (new gt labels)
+
+        # get all the indices in X which were not matched to any point of Y
+        residual_ind = np.delete(np.arange(X.shape[0]), ind.reshape(-1))  # indices of X which were not matched to Y
+
+        # create a dictionary which contains X, Y, Y_labels and ind_labels_Y
+        dict_Y = {
+            'X': X,  # points of the input las file
+            'Y': Y,  # points of the instance segmented las file
+            'Y_labels': Y_labels,  # labels of the instance segmented las file
+            'ind_labels_Y': ind_labels_Y,  # labels from X matched to Y (new gt labels)
+            'ind': ind,  # indices of X which were matched to Y
+            'residual_ind': residual_ind  # indices of X which were not matched to Y
+        }
+
+        return dict_Y
+
+    def get_dominant_labels_sorted(self):
+        # get the unique labels from Y_labels
+        Y_unique_labels = np.unique(self.Y_labels)
+
+        dominant_labels = {}
+        for label in Y_unique_labels:
+            # get the indices of Y_labels == label
+            ind_Y_labels = np.where(self.Y_labels == label)[0]
+            # get the ind_labels_Y for these indices
+            ind_labels_Y = self.dict_Y['ind_labels_Y'][ind_Y_labels]
+            # get the unique ind_labels_Y
+            unique_ind_labels_Y = np.unique(ind_labels_Y)
+            # count the number of points for each unique ind_labels_Y
+            tmp_dict = {}
+            for unique_ind_label_Y in unique_ind_labels_Y:
+                # get the indices of ind_labels_Y == unique_ind_label_Y
+                ind_ind_labels_Y = np.where(ind_labels_Y == unique_ind_label_Y)[0]
+                # put the number of points into tmp_dict
+                tmp_dict[str(unique_ind_label_Y)] = ind_ind_labels_Y.shape[0]
+
+            # put the per-gt-label point counts into dominant_labels
+            dominant_labels[str(label)] = tmp_dict
+
+        # sort each sub-dictionary by the number of points
+        dominant_labels_sorted = {}
+        for key, value in dominant_labels.items():
+            dominant_labels_sorted[key] = {k: v for k, v in sorted(value.items(), key=lambda item: item[1], reverse=True)}
+
+        # sort dominant_labels_sorted based on the first (largest) value of each sub-dictionary
+        dominant_labels_sorted = {
+            k: v for k, v in sorted(dominant_labels_sorted.items(), key=lambda item: list(item[1].values())[0], reverse=True)}
+
+        return dominant_labels_sorted
+
+    def extract_from_sub_dict(self, target_dict, label):
+        new_dict = {}
+
+        for key_outer, value_outer in target_dict.items():
+            tmp_dict = {}
+
+            for item_inner in value_outer.keys():
+                if item_inner == label:
+                    tmp_dict[item_inner] = value_outer[item_inner]
+            new_dict[key_outer] = tmp_dict
+        return new_dict
+
+    # find the classes in input_file sorted by the number of points
+    def find_dominant_classes_in_gt(self, input_file):
+        # get the unique labels
+        unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int)
+        # create a dictionary
+        tmp_dict = {}
+        for label in unique_labels:
+            # get the indices of input_file[GT_LABEL_NAME] == label
+            ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0]
+            # put the number of points into tmp_dict
+            tmp_dict[str(label)] = ind_label.shape[0]
+        # sort tmp_dict by the number of points
+        tmp_dict_sorted = {k: v for k, v in sorted(tmp_dict.items(), key=lambda item: item[1], reverse=True)}
+
+        # remove key 0 (ground) from tmp_dict_sorted
+        if self.remove_ground:
+            tmp_dict_sorted.pop('0', None)
+
+        return tmp_dict_sorted.keys()
+
+    def get_the_dominant_label(self, dominant_labels_sorted):
+        # get the dominant label:
+        # drop empty sub-dictionaries, then sort by the first (largest) value of each sub-dictionary
+        for key, value in dominant_labels_sorted.copy().items():
+            if not value:
+                del dominant_labels_sorted[key]
+
+        dominant_labels_sorted = {
+            k: v for k, v in sorted(dominant_labels_sorted.items(), key=lambda item: list(item[1].values())[0], reverse=True)}
+
+        dominant_label_key = list(dominant_labels_sorted.keys())[0]
+        dominant_label_value = list(dominant_labels_sorted.values())[0]
+        dominant_label = list(dominant_label_value.keys())[0]
+
+        # find the dominant_label_key and dominant_label with the highest point count
+        for key, value in dominant_labels_sorted.items():
+            for item in value.keys():
+                if value[item] > dominant_label_value[dominant_label]:
+                    dominant_label_key = key
+                    dominant_label = item
+                    dominant_label_value = value
+
+        return dominant_label_key, dominant_label
+
+    def remove_dominant_label(self, dominant_labels_sorted, dominant_label_key, dominant_label):
+        # remove the dominant_label_key from dominant_labels_sorted
+        dominant_labels_sorted.pop(dominant_label_key)
+        # remove the dominant_label from the sub-dictionaries of dominant_labels_sorted
+        for key, value in dominant_labels_sorted.items():
+            if dominant_label in value:
+                value.pop(dominant_label)
+
+        return dominant_labels_sorted
+
+    def iterate_over_pc(self):
+
+        label_mapping_dict = {}
+
+        dominant_labels_sorted = self.get_dominant_labels_sorted()
+        gt_classes_to_iterate = self.find_dominant_classes_in_gt(self.input_las)
+
+        for gt_class in gt_classes_to_iterate:
+
+            if self.remove_ground:
+                # check if all the sub-dictionaries have only one key and it is 0
+                if all(len(v) == 1 and '0' in v for v in dominant_labels_sorted.values()):
+                    break
+
+            # if all the values in dominant_labels_sorted are empty, break the loop
+            if not any(dominant_labels_sorted.values()):
+                break
+
+            if len(dominant_labels_sorted) == 1:
+                dominant_label_key, dominant_label = self.get_the_dominant_label(dominant_labels_sorted)
+                label_mapping_dict[dominant_label_key] = dominant_label
+                break
+
+            extracted = self.extract_from_sub_dict(dominant_labels_sorted, gt_class)
+
+            # if extracted is empty, continue
+            if not extracted:
+                continue
+
+            # if all the values in extracted are empty, continue
+            if not any(extracted.values()):
+                continue
+
+            dominant_label_key, dominant_label = self.get_the_dominant_label(extracted)
+
+            self.remove_dominant_label(dominant_labels_sorted, dominant_label_key, dominant_label)
+
+            label_mapping_dict[dominant_label_key] = dominant_label
+
+        # change keys and values to int
+        label_mapping_dict = {int(k): int(v) for k, v in label_mapping_dict.items()}
+
+        return label_mapping_dict
+
+    def compute_metrics(self):
+        # get the label_mapping_dict and compute the per-tree metrics
+        metric_dict = {}
+
+        if not self.skip_flag:
+            label_mapping_dict = self.iterate_over_pc()
+
+            if self.verbose:
+                print('Computing metrics for individual trees...')
+
+            for label in tqdm(list(label_mapping_dict.keys())):
+                # get the indices of Y_labels == label
+                ind_Y_labels_label = np.where(self.Y_labels == label)[0]  # indices of Y_labels == label
+
+                # get the X labels for these indices
+                ind_labels_Y = self.dict_Y['ind_labels_Y'][ind_Y_labels_label]  # X labels for these indices
+
+                # get the dominant (ground truth) label for this predicted label
+                dominant_label = label_mapping_dict[label]
+
+                # get the indices of ind_labels_Y == dominant_label
+                ind_dominant_label = np.where(ind_labels_Y == dominant_label)[0]
+
+                ## true positive: the number of points of the dominant_label within the prediction
+                true_positive = ind_dominant_label.shape[0]
+
+                ## false positive: points within the relabelled prediction which are not the dominant_label
+                false_positive = ind_Y_labels_label.shape[0] - true_positive
+
+                ## false negative: points of the dominant_label in X which were not matched to any point of Y
+                false_negative = np.where(self.X_labels[self.dict_Y['residual_ind']] == dominant_label)[0].shape[0]
+
+                ## true negative
+                true_negative = self.dict_Y['X'].shape[0] - false_negative - true_positive - false_positive
+
+                # sum of true_positive, false_positive, false_negative and true_negative
+                sum_all = true_positive + false_positive + false_negative + true_negative
+
+                # get precision
+                precision = true_positive / (true_positive + false_positive)
+                # get recall
+                recall = true_positive / (true_positive + false_negative)
+                # get the f1 score
+                f1_score = 2 * (precision * recall) / (precision + recall)
+                # get the IoU
+                IoU = true_positive / (true_positive + false_positive + false_negative)
+
+                # find the height of the tree in the ground truth
+                height_of_tree_gt = (self.input_las[self.X_labels == dominant_label].z).max() - (self.input_las[self.X_labels == dominant_label].z).min()
+                # find the height of the tree in the prediction
+                height_of_tree_pred = (self.instance_segmented_las[self.Y_labels == label].z).max() - (self.instance_segmented_las[self.Y_labels == label].z).min()
+
+                # get the absolute residual of the height of the tree in the prediction
+                residual_height_of_tree_pred = abs(height_of_tree_gt - height_of_tree_pred)
+
+                # create a tmp dict with the per-tree metrics
+                tmp_dict = {
+                    'pred_label': label,
+                    'gt_label(dominant_label)': dominant_label,
+                    'height_of_tree_gt': height_of_tree_gt,
+                    'height_of_tree_pred': height_of_tree_pred,
+                    'residual_height(gt_minus_pred)': residual_height_of_tree_pred,
+                    'sum_all': sum_all,
+                    'true_positive': true_positive,
+                    'false_positive': false_positive,
+                    'false_negative': false_negative,
+                    'true_negative': true_negative,
+                    'precision': precision,
+                    'recall': recall,
+                    'f1_score': f1_score,
+                    'IoU': IoU,
+                }
+                metric_dict[str(label)] = tmp_dict
+
+        # list of interesting metrics
+        interesting_parameters = ['precision', 'recall', 'f1_score', 'IoU', 'residual_height(gt_minus_pred)']
+
+        # weight the metrics by tree height
+        metric_dict_weighted_by_tree_height = {}
+        # initialize the metric_dict_weighted_by_tree_height
+        for parameter in interesting_parameters:
+            metric_dict_weighted_by_tree_height[parameter] = 0
+
+        # do this if there is at least one label
+        if metric_dict:
+            for label in metric_dict.keys():
+                print('label: ', label)
+                for parameter in interesting_parameters:
+                    print('parameter: ', parameter)
+                    print('metric_dict[label][parameter]: ', metric_dict[label][parameter])
+                    metric_dict_weighted_by_tree_height[parameter] += metric_dict[label]['height_of_tree_gt'] * metric_dict[label][parameter]
+            # divide by the sum of the heights of the trees
+            for parameter in interesting_parameters:
+                metric_dict_weighted_by_tree_height[parameter] /= sum([metric_dict[label]['height_of_tree_gt'] for label in metric_dict.keys()])
+
+        # compute the mean of the metrics
+        metric_dict_mean = {}
+        for parameter in interesting_parameters:
+            metric_dict_mean[parameter] = 0
+
+        # do this if there is at least one label
+        if metric_dict:
+            for key, value in metric_dict.items():
+                for parameter in interesting_parameters:
+                    metric_dict_mean[parameter] += value[parameter]
+
+            for parameter in interesting_parameters:
+                metric_dict_mean[parameter] = metric_dict_mean[parameter] / len(metric_dict)
+
+        # compute tree level metrics
+        if metric_dict:
+            # get the trees in the ground truth
+            gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
+
+            # remove 0 (ground) from gt_trees
+            gt_trees = gt_trees[gt_trees != 0]
+
+            # get the ground truth trees that are covered by a prediction
+            trees_predicted = np.unique([metric_dict[key]['gt_label(dominant_label)'] for key in metric_dict.keys()])
+
+            # get the ground truth trees that are predicted correctly, i.e. with IoU > 0.5
+            trees_correctly_predicted_IoU = np.unique([metric_dict[key]['gt_label(dominant_label)'] for key in metric_dict.keys() if metric_dict[key]['IoU'] > 0.5])
+
+            # convert to sets
+            gt_trees = set(gt_trees)
+            trees_predicted = set(trees_predicted)
+            trees_correctly_predicted_IoU = set(trees_correctly_predicted_IoU)
+
+            tree_level_metric = {
+                'true_positive (detection rate)': len(trees_correctly_predicted_IoU) / len(gt_trees),
+                'false_positive (commission)': len(trees_predicted - trees_correctly_predicted_IoU) / len(gt_trees),
+                'false_negative (omissions)': len(gt_trees - trees_predicted - trees_correctly_predicted_IoU) / len(gt_trees),
+                'gt': len(gt_trees)}
+
+            # add the tree level metrics to metric_dict_mean
+            metric_dict_mean.update(tree_level_metric)
+
+            if self.verbose:
+                print('Tree level metrics:')
+                print(f'Trees in the ground truth: {gt_trees}')
+                print(f'Trees predicted: {trees_predicted}')
+                print(f'Trees correctly predicted with IoU > 0.5: {trees_correctly_predicted_IoU}')
+
+                print(tree_level_metric)
+
+        return metric_dict, metric_dict_weighted_by_tree_height, metric_dict_mean
+
+    def print_metrics(self, metric_dict):
+        for key, value in metric_dict.items():
+            print(f'Label: {key}')
+            for key2, value2 in value.items():
+                print(f'{key2}: {value2}')
+            print('')
+
+    def save_to_csv_file(self, metric_dict):
+        df = pd.DataFrame(metric_dict).T
+        # save to a csv file and keep the index
+        df.to_csv(self.csv_file_name, index=True, header=True)
+
+    def main(self):
+        metric_dict, metric_dict_weighted_by_tree_height, metric_dict_mean = self.compute_metrics()
+
+        if self.verbose:
+            f1_weighted_by_tree_height = metric_dict_weighted_by_tree_height['f1_score']
+            print(f'f1_score_weighted: {f1_weighted_by_tree_height}')
+            for key, value in metric_dict.items():
+                print(f'Label: {key}')
+                for key2, value2 in value.items():
+                    print(f'{key2}: {value2}')
+                print('')
+
+        if self.csv_file_name is not None:
+            self.save_to_csv_file(metric_dict)
+            if self.verbose:
+                print(f'Metrics saved to {self.csv_file_name}')
+
+        return metric_dict, metric_dict_weighted_by_tree_height, metric_dict_mean
+
+
+# main
+if __name__ == '__main__':
+    # parse input_file_path, instance_segmented_file_path and the optional flags
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file_path', type=str, required=True)
+    parser.add_argument('--instance_segmented_file_path', type=str, required=True)
+    parser.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
+    parser.add_argument('--csv_file_name', type=str, help="Name of the csv file to save the metrics to", default=None)
+    parser.add_argument('--verbose', action='store_true', help="Print information about the process", default=False)
+
+    args = parser.parse_args()
+
+    # create an instance of the class InstanceSegmentationMetrics
+    instance_segmentation_metrics = InstanceSegmentationMetrics(
+        args.input_file_path,
+        args.instance_segmented_file_path,
+        args.remove_ground,
+        args.csv_file_name,
+        args.verbose
+    )
+
+    # compute the metrics
+    metric_dict, _, _ = instance_segmentation_metrics.main()
\ No newline at end of file
--
GitLab
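
A minimal usage sketch, assuming the patched module can be imported (e.g. when running from the metrics/ directory) and that the two LAS files carry the 'StemID' and 'instance_nr' dimensions checked above; the file paths below are placeholders, not part of the patch:

    # Hypothetical usage sketch; paths are placeholders.
    from instance_segmentation_metrics_austrian import InstanceSegmentationMetrics

    metrics = InstanceSegmentationMetrics(
        input_file_path='ground_truth.las',              # LAS file with a 'StemID' dimension
        instance_segmented_file_path='prediction.las',   # LAS file with an 'instance_nr' dimension
        remove_ground=True,                              # ignore class 0 when mapping labels
        csv_file_name='per_tree_metrics.csv',            # per-tree metrics are written here
        verbose=True,
    )

    # main() returns the per-tree metrics, the height-weighted averages and the plain means
    # (the means also carry the tree-level detection/commission/omission rates).
    per_tree, weighted_by_height, mean_metrics = metrics.main()
    print(mean_metrics)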