Skip to content
Snippets Groups Projects
Commit 9925c0b5 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated matrics for instance segmentation - austrian dataset

parent 781bd59b
Branches
No related tags found
No related merge requests found
import argparse
import os
import laspy
import logging
import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
class InstanceSegmentationMetrics:
GT_LABEL_NAME = 'StemID' #GT_LABEL_NAME = 'StemID'
TARGET_LABEL_NAME = 'instance_nr'
def __init__(
self,
input_file_path,
instance_segmented_file_path,
remove_ground = False,
csv_file_name=None,
verbose=False
) -> None:
self.input_file_path = input_file_path
self.instance_segmented_file_path = instance_segmented_file_path
self.remove_ground = remove_ground
self.csv_file_name = csv_file_name
self.verbose = verbose
# read and prepare input las file and instance segmented las file
self.input_las = laspy.read(self.input_file_path)
self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
self.skip_flag = self.check_if_labels_exist(
X_label=self.GT_LABEL_NAME,
Y_label=self.TARGET_LABEL_NAME
)
if not self.skip_flag:
# get labels from input las file
self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int)
# get labels from instance segmented las file
self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int)
# if self.remove_ground:
# # the labeling starts from 0, so we need to remove the ground
# self.Y_labels += 1
# do knn mapping
self.dict_Y = self.do_knn_mapping()
else:
logging.info('Skipping the file: {}'.format(self.input_file_path))
def check_if_labels_exist(self, X_label='treeID', Y_label='instance_nr'):
# check if the labels exist in the las files
skip_flag = False
if X_label not in self.input_las.header.point_format.dimension_names:
skip_flag = True
if Y_label not in self.instance_segmented_las.header.point_format.dimension_names:
skip_flag = True
return skip_flag
def do_knn_mapping(self):
X = self.input_las.xyz
Y = self.instance_segmented_las.xyz
X_labels = self.X_labels
Y_labels = self.Y_labels
# create a KDTree for X
tree = KDTree(X, leaf_size=50, metric='euclidean')
# query the tree for Y
ind = tree.query(Y, k=1, return_distance=False)
# get labels for ind
ind_labels_Y = X_labels[ind]
# reshape to 1D
ind_labels_Y = ind_labels_Y.reshape(-1) # labels from X matched to Y (new gt labels)
# get all the indices in X which were matched to Y
residual_ind = np.delete(np.arange(X.shape[0]), ind.reshape(-1)) # indices of X which were not matched to Y
# create a dictionary which contains Y, Y_labels and ind_labels_Y
dict_Y = {
'X': X, # X is the input las file
'Y': Y, # Y is the instance segmented las file
'Y_labels': Y_labels, # Y_labels is the instance segmented las file
'ind_labels_Y': ind_labels_Y, # ind_labels_Y is the labels from X matched to Y (new gt labels)
'ind': ind, # ind is the indices of X which were matched to Y
'residual_ind': residual_ind # residual_ind is the indices of X which were not matched to Y
}
return dict_Y
def get_dominant_lables_sorted(self):
# get unique labels from Y_labels
Y_unique_labels = np.unique(self.Y_labels)
dominant_labels = {}
for label in Y_unique_labels:
# get the indices of Y_labels == label
ind_Y_labels = np.where(self.Y_labels == label)[0]
# get the ind_labels_Y for these indices
ind_labels_Y = self.dict_Y['ind_labels_Y'][ind_Y_labels]
# get the unique ind_labels_Y
unique_ind_labels_Y = np.unique(ind_labels_Y)
# print the number of points for each unique ind_labels_Y
tmp_dict = {}
for unique_ind_label_Y in unique_ind_labels_Y:
# get the indices of ind_labels_Y == unique_ind_label_Y
ind_ind_labels_Y = np.where(ind_labels_Y == unique_ind_label_Y)[0]
# put the number of points to the tmp_dict
tmp_dict[str(unique_ind_label_Y)] = ind_ind_labels_Y.shape[0]
# put the dominant label to the dominant_labels
dominant_labels[str(label)] = tmp_dict
# sort dominant_labels by the number of points
dominant_labels_sorted = {}
for key, value in dominant_labels.items():
dominant_labels_sorted[key] = {k: v for k, v in sorted(value.items(), key=lambda item: item[1], reverse=True)}
# iterate over the dominant_labels_sorted and sort it based on the first value of sub-dictionary
dominant_labels_sorted = {
k: v for k, v in sorted(dominant_labels_sorted.items(), key=lambda item: list(item[1].values())[0], reverse=True)}
return dominant_labels_sorted
def extract_from_sub_dict(self, target_dict, label):
new_dict = {}
for key_outer, value_outer in target_dict.items():
tmp_dict = {}
for item_inner in value_outer.keys():
if item_inner == label:
tmp_dict[item_inner] = value_outer[item_inner]
new_dict[key_outer] = (tmp_dict)
return new_dict
# define a function that finds class in input_file with the most points
def find_dominant_classes_in_gt(self, input_file):
# get the unique labels
unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int)
# create a dictionary
tmp_dict = {}
for label in unique_labels:
# get the indices of input_file.treeID == label
ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0]
# put the number of points to the tmp_dict
tmp_dict[str(label)] = ind_label.shape[0]
# sort tmp_dict by the number of points
tmp_dict_sorted = {k: v for k, v in sorted(tmp_dict.items(), key=lambda item: item[1], reverse=True)}
# remove key 0 from tmp_dict_sorted
if self.remove_ground:
tmp_dict_sorted.pop('0', None)
return tmp_dict_sorted.keys()
def get_the_dominant_label(self, dominant_labels_sorted):
# get the dominant label
# iterate over the dominant_labels_sorted and sort it based on the first value of sub-dictionary
# if sub-dictionary is empty, remove the key from the dictionary
for key, value in dominant_labels_sorted.copy().items():
if not value:
del dominant_labels_sorted[key]
dominant_labels_sorted = {
k: v for k, v in sorted(dominant_labels_sorted.items(), key=lambda item: list(item[1].values())[0], reverse=True)}
dominant_label_key = list(dominant_labels_sorted.keys())[0]
dominant_label_value = list(dominant_labels_sorted.values())[0]
dominant_label = list(dominant_label_value.keys())[0]
# get dominant_label_key and dominant_label for which dominant_label_value.values() has the highest value
for key, value in dominant_labels_sorted.items():
for item in value.keys():
if value[item] > dominant_label_value[dominant_label]:
dominant_label_key = key
dominant_label = item
dominant_label_value = value
return dominant_label_key, dominant_label
def remove_dominant_label(self, dominant_labels_sorted, dominant_label_key, dominant_label):
# remove the dominant_label_key from dominant_labels_sorted
dominant_labels_sorted.pop(dominant_label_key)
# remove the dominant_label from the sub-dictionary of dominant_labels_sorted
for key, value in dominant_labels_sorted.items():
if dominant_label in value:
value.pop(dominant_label)
return dominant_labels_sorted
def iterate_over_pc(self):
label_mapping_dict = {}
dominant_labels_sorted = self.get_dominant_lables_sorted()
gt_classes_to_iterate = self.find_dominant_classes_in_gt(self.input_las)
for gt_class in gt_classes_to_iterate:
# if all the values in dominant_labels_sorted are empty, break the loop
if self.remove_ground:
# check if all the sub-dictionaries have only one key and it is 0
if all(len(v) == 1 and '0' in v for v in dominant_labels_sorted.values()):
break
if not any(dominant_labels_sorted.values()):
break
if len(dominant_labels_sorted) == 1:
dominant_label_key, dominant_label = self.get_the_dominant_label(dominant_labels_sorted)
label_mapping_dict[dominant_label_key] = dominant_label
break
extracted = self.extract_from_sub_dict(dominant_labels_sorted, gt_class)
# if extracted is empty, continue
if not extracted:
continue
# if all the values in extracted are empty, continue
if not any(extracted.values()):
continue
dominant_label_key, dominant_label = self.get_the_dominant_label(extracted)
self.remove_dominant_label(dominant_labels_sorted, dominant_label_key, dominant_label)
label_mapping_dict[dominant_label_key] = dominant_label
# change keys and values to int
label_mapping_dict = {int(k): int(v) for k, v in label_mapping_dict.items()}
return label_mapping_dict
def compute_metrics(self):
# get the label_mapping_dict
metric_dict = {}
if not self.skip_flag:
label_mapping_dict = self.iterate_over_pc()
if self.verbose:
print('Computing metrics for individual trees...')
for label in tqdm(list(label_mapping_dict.keys())):
# get the indices of Y_labels == label
ind_Y_labels_label = np.where(self.Y_labels == label)[0] # indices of Y_labels == label
# get the X labels for these indices
ind_labels_Y = self.dict_Y['ind_labels_Y'][ind_Y_labels_label] # X labels for these indices
# get the dominant label for this label
dominant_label = label_mapping_dict[label]
# get the indices of ind_labels_Y == dominant_label
ind_dominant_label = np.where(ind_labels_Y == dominant_label)[0]
## true positive is the number of points for dominant_label
true_positive = ind_dominant_label.shape[0]
## points which are within the relabelled pred but are not dominant_label
false_positive = ind_Y_labels_label.shape[0] - true_positive
## false negative is the number of points which are not in Y but are in X
false_negative = np.where(self.X_labels[self.dict_Y['residual_ind']] == dominant_label)[0].shape[0]
## true negative
true_negative = self.dict_Y['X'].shape[0] - false_negative - true_positive - false_positive
# sum all the true_positive, false_positive, false_negative, true_negative
sum_all = true_positive + false_positive + false_negative + true_negative
# get precision
precision = true_positive / (true_positive + false_positive)
# get recall
recall = true_positive / (true_positive + false_negative)
# get f1 score
f1_score = 2 * (precision * recall) / (precision + recall)
# get IoU
IoU = true_positive / (true_positive + false_positive + false_negative)
# find hight of the tree in the ground truth
hight_of_tree_gt = (self.input_las[self.X_labels == dominant_label].z).max() - (self.input_las[self.X_labels == dominant_label].z).min()
# find hight of the tree in the prediction
hight_of_tree_pred = (self.instance_segmented_las[self.Y_labels == label].z).max() - (self.instance_segmented_las[self.Y_labels == label].z).min()
# get abs resiudal of the hight of the tree in the prediction
residual_hight_of_tree_pred = abs(hight_of_tree_gt - hight_of_tree_pred)
# create tmp dict
tmp_dict = {
'pred_label': label,
'gt_label(dominant_label)': dominant_label,
'high_of_tree_gt': hight_of_tree_gt,
'high_of_tree_pred': hight_of_tree_pred,
'residual_hight(gt_minus_pred)': residual_hight_of_tree_pred,
'sum_all': sum_all,
'true_positive': true_positive,
'false_positive': false_positive,
'false_negative': false_negative,
'true_negative': true_negative,
'precision': precision,
'recall': recall,
'f1_score': f1_score,
'IoU': IoU,
}
metric_dict[str(label)] = tmp_dict
# list of interesting metrics
interesting_parameters = ['precision', 'recall', 'f1_score', 'IoU', 'residual_hight(gt_minus_pred)']
# weight the metrics by tree hight
metric_dict_weighted_by_tree_hight = {}
# itialize the metric_dict_weighted_by_tree_hight
for parameter in interesting_parameters:
metric_dict_weighted_by_tree_hight[parameter] = 0
# do this if there is at least one label
if metric_dict:
for label in metric_dict.keys():
print('label: ', label)
for parameter in interesting_parameters:
print('parameter: ', parameter)
print('metric_dict[label][parameter]: ', metric_dict[label][parameter])
metric_dict_weighted_by_tree_hight[parameter] += metric_dict[label]['high_of_tree_gt'] * metric_dict[label][parameter]
# divide by the sum of the hights of the trees
for parameter in interesting_parameters:
metric_dict_weighted_by_tree_hight[parameter] /= sum([metric_dict[label]['high_of_tree_gt'] for label in metric_dict.keys()])
# compute the mean of the metrics
metric_dict_mean = {}
for parameter in interesting_parameters:
metric_dict_mean[parameter] = 0
# do this if there is at least one label
if metric_dict:
for key, value in metric_dict.items():
for parameter in interesting_parameters:
metric_dict_mean[parameter] += value[parameter]
for parameter in interesting_parameters:
metric_dict_mean[parameter] = metric_dict_mean[parameter] / len(metric_dict)
# compute tree level metrics
if metric_dict:
# get the number of trees in the ground truth
gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
# remove 0 from gt_trees
gt_trees = gt_trees[gt_trees != 0]
# get the number of trees that are predicted correctly
trees_predicted = np.unique([metric_dict[key]['gt_label(dominant_label)'] for key in metric_dict.keys()])
# iterate over metric_dict and get the number of trees that are predicted correctly with IoU > 0.5
trees_correctly_predicted_IoU = np.unique([metric_dict[key]['gt_label(dominant_label)'] for key in metric_dict.keys() if metric_dict[key]['IoU'] > 0.5])
# convert to set
gt_trees = set(gt_trees)
trees_predicted = set(trees_predicted)
trees_correctly_predicted_IoU = set(trees_correctly_predicted_IoU)
tree_level_metric = {
'true_positve (detection rate)': len(trees_correctly_predicted_IoU) / len(gt_trees),
'false_positve (commission)': len(trees_predicted - trees_correctly_predicted_IoU) / len(gt_trees),
'false_negative (omissions)': len(gt_trees - trees_predicted - trees_correctly_predicted_IoU) / len(gt_trees),
'gt': len(gt_trees)}
# add tree level metrics to the metric_dict_mean
metric_dict_mean.update(tree_level_metric)
if self.verbose:
print('Tree level metrics:')
print(f'Trees in the ground truth: {gt_trees}')
print(f'Trees correctly predicted: {trees_predicted}')
print(f'Trees correctly predicted with IoU > 0.5: {trees_correctly_predicted_IoU}')
print(tree_level_metric)
return metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean
def print_metrics(self, metric_dict):
for key, value in metric_dict.items():
print(f'Label: {key}')
for key2, value2 in value.items():
print(f'{key2}: {value2}')
print('')
def save_to_csv_file(self, metric_dict):
df = pd.DataFrame(metric_dict).T
# save to csv file and show the index
df.to_csv(self.csv_file_name, index=True, header=True)
def main(self):
metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean = self.compute_metrics()
if self.verbose:
f1_weighted_by_tree_hight = metric_dict_weighted_by_tree_hight['f1_score']
print(f'f1_score_weighted: {f1_weighted_by_tree_hight}')
for key, value in metric_dict.items():
print(f'Label: {key}')
for key2, value2 in value.items():
print(f'{key2}: {value2}')
print('')
if self.csv_file_name is not None:
self.save_to_csv_file(metric_dict)
if self.verbose:
print(f'Metrics saved to {self.csv_file_name}')
return metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean
# main
if __name__ == '__main__':
# do argparse input_file_path, instance_segmented_file_path, verbose
parser = argparse.ArgumentParser()
parser.add_argument('--input_file_path', type=str, required=True)
parser.add_argument('--instance_segmented_file_path', type=str, required=True)
parser.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
parser.add_argument('--csv_file_name', type=str, help="Name of the csv file to save the metrics to", default=None)
parser.add_argument('--verbose', action='store_true', help="Print information about the process", default=False)
args = parser.parse_args()
# create instance of the class InstanceSegmentationMetrics
instance_segmentation_metrics = InstanceSegmentationMetrics(
args.input_file_path,
args.instance_segmented_file_path,
args.remove_ground,
args.csv_file_name,
args.verbose
)
# compute metrics
metric_dict, _, _ = instance_segmentation_metrics.main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment