# instance_segmentation_metrics_austrian.py
# Author: Maciej Wielgosz
import argparse
import os
import laspy
import logging
import numpy as np
import pandas as pd
from sklearn.neighbors import KDTree
from tqdm import tqdm
# Configure the root logger at INFO level when this module is imported, so
# logging.info messages (e.g. the "Skipping the file" notice) are shown.
logging.basicConfig(level=logging.INFO)
class InstanceSegmentationMetrics:
    """Compare a predicted instance segmentation of a point cloud with its ground truth.

    Ground-truth tree ids are read from the ``GT_LABEL_NAME`` dimension of the
    input LAS file, and predicted instance ids from the ``TARGET_LABEL_NAME``
    dimension of the segmented LAS file. Each predicted point inherits the
    ground-truth label of its nearest input point. Predicted instances are then
    greedily matched one-to-one to ground-truth labels. Point-level metrics
    (precision, recall, F1, IoU, height residual) are computed for every matched
    pair, together with tree-level detection / commission / omission rates.
    """

    # Dimension of the input (ground-truth) LAS file holding tree ids.
    GT_LABEL_NAME = 'StemID'
    # Dimension of the instance-segmented LAS file holding predicted instance ids.
    TARGET_LABEL_NAME = 'instance_nr'

    def __init__(
        self,
        input_file_path,
        instance_segmented_file_path,
        remove_ground=False,
        csv_file_name=None,
        verbose=False
    ) -> None:
        """Read both LAS files and, if both label dimensions exist, build the NN mapping.

        Args:
            input_file_path: path of the ground-truth LAS file.
            instance_segmented_file_path: path of the predicted LAS file.
            remove_ground: if True, ground-truth label 0 is never matched to a
                predicted instance.
            csv_file_name: optional path that ``main`` writes per-tree metrics to.
            verbose: print progress and results.
        """
        self.input_file_path = input_file_path
        self.instance_segmented_file_path = instance_segmented_file_path
        self.remove_ground = remove_ground
        self.csv_file_name = csv_file_name
        self.verbose = verbose
        # read both point clouds
        self.input_las = laspy.read(self.input_file_path)
        self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
        self.skip_flag = self.check_if_labels_exist(
            X_label=self.GT_LABEL_NAME,
            Y_label=self.TARGET_LABEL_NAME
        )
        if not self.skip_flag:
            # ground-truth tree ids (X) and predicted instance ids (Y)
            self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int)
            self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int)
            # transfer ground-truth labels onto the predicted points
            self.dict_Y = self.do_knn_mapping()
        else:
            logging.info('Skipping the file: {}'.format(self.input_file_path))

    def check_if_labels_exist(self, X_label='treeID', Y_label='instance_nr'):
        """Return True (i.e. "skip this file") if either label dimension is missing.

        Args:
            X_label: dimension expected in the ground-truth LAS file.
            Y_label: dimension expected in the predicted LAS file.
        """
        skip_flag = False
        if X_label not in self.input_las.header.point_format.dimension_names:
            skip_flag = True
        if Y_label not in self.instance_segmented_las.header.point_format.dimension_names:
            skip_flag = True
        return skip_flag

    def do_knn_mapping(self):
        """Give every predicted point the ground-truth label of its nearest input point.

        Returns:
            dict with keys
              'X': input (ground-truth) coordinates,
              'Y': predicted coordinates,
              'Y_labels': predicted instance ids,
              'ind_labels_Y': ground-truth label of the nearest input point for
                  every predicted point (flattened to 1-D),
              'ind': nearest input-point index for every predicted point, as
                  returned by ``KDTree.query`` with k=1,
              'residual_ind': indices of input points that are not the nearest
                  neighbour of any predicted point.
        """
        X = self.input_las.xyz
        Y = self.instance_segmented_las.xyz
        tree = KDTree(X, leaf_size=50, metric='euclidean')
        # nearest ground-truth point for each predicted point
        ind = tree.query(Y, k=1, return_distance=False)
        ind_labels_Y = self.X_labels[ind].reshape(-1)
        # input points that no predicted point picked as its nearest neighbour
        residual_ind = np.delete(np.arange(X.shape[0]), ind.reshape(-1))
        return {
            'X': X,
            'Y': Y,
            'Y_labels': self.Y_labels,
            'ind_labels_Y': ind_labels_Y,
            'ind': ind,
            'residual_ind': residual_ind,
        }

    def get_dominant_lables_sorted(self):
        """Build the overlap table between predicted instances and ground-truth labels.

        Returns:
            {str(pred_label): {str(gt_label): n_points}}. Each inner dict is sorted
            by point count (descending). The outer dict is sorted by each
            instance's largest overlap (descending). Ties keep ascending label order.
        """
        transferred_gt = self.dict_Y['ind_labels_Y']
        overlap = {}
        for label in np.unique(self.Y_labels):
            gt_values, gt_counts = np.unique(
                transferred_gt[self.Y_labels == label], return_counts=True)
            counts = {str(gt_value): int(gt_count)
                      for gt_value, gt_count in zip(gt_values, gt_counts)}
            overlap[str(label)] = dict(
                sorted(counts.items(), key=lambda item: item[1], reverse=True))
        # every inner dict is non-empty, so its first value is its maximum
        return dict(sorted(
            overlap.items(),
            key=lambda item: next(iter(item[1].values())),
            reverse=True))

    def extract_from_sub_dict(self, target_dict, label):
        """Keep only the entry ``label`` in every sub-dictionary of ``target_dict``.

        Outer keys are always kept; their sub-dictionary is empty when it does
        not contain ``label``.
        """
        return {
            key_outer: {k: v for k, v in value_outer.items() if k == label}
            for key_outer, value_outer in target_dict.items()
        }

    def find_dominant_classes_in_gt(self, input_file):
        """Return ground-truth labels (as strings) ordered by point count, largest first.

        Label '0' is dropped when ``remove_ground`` is set.

        Args:
            input_file: object indexable by ``GT_LABEL_NAME`` (the input LasData).

        Returns:
            dict_keys of label strings.
        """
        gt_values = input_file[self.GT_LABEL_NAME]
        point_counts = {}
        for label in np.unique(gt_values).astype(int):
            point_counts[str(label)] = int(np.count_nonzero(gt_values == label))
        point_counts_sorted = dict(
            sorted(point_counts.items(), key=lambda item: item[1], reverse=True))
        if self.remove_ground:
            point_counts_sorted.pop('0', None)
        return point_counts_sorted.keys()

    def get_the_dominant_label(self, dominant_labels_sorted):
        """Return the (pred_label, gt_label) pair with the largest overlap.

        Entries whose sub-dictionary is empty are deleted from
        ``dominant_labels_sorted`` in place. On ties the earliest pair in
        descending-overlap order wins.

        Args:
            dominant_labels_sorted: {pred_label: {gt_label: n_points}}; must hold
                at least one non-empty sub-dictionary.

        Returns:
            (pred_label, gt_label) using the string keys of the input.
        """
        for key, value in dominant_labels_sorted.copy().items():
            if not value:
                del dominant_labels_sorted[key]
        dominant_labels_sorted = {
            k: v for k, v in sorted(
                dominant_labels_sorted.items(),
                key=lambda item: list(item[1].values())[0],
                reverse=True)}
        dominant_label_key = list(dominant_labels_sorted.keys())[0]
        dominant_label_value = list(dominant_labels_sorted.values())[0]
        dominant_label = list(dominant_label_value.keys())[0]
        # Scan every pair as well: sub-dictionaries are not guaranteed to be
        # sorted, so the first entry may not be the global maximum.
        for key, value in dominant_labels_sorted.items():
            for item in value.keys():
                if value[item] > dominant_label_value[dominant_label]:
                    dominant_label_key = key
                    dominant_label = item
                    dominant_label_value = value
        return dominant_label_key, dominant_label

    def remove_dominant_label(self, dominant_labels_sorted, dominant_label_key, dominant_label):
        """Remove a matched pair from the overlap table (in place) and return it.

        The predicted instance ``dominant_label_key`` is dropped entirely, and
        ``dominant_label`` is removed from every remaining sub-dictionary.
        """
        dominant_labels_sorted.pop(dominant_label_key)
        for value in dominant_labels_sorted.values():
            if dominant_label in value:
                value.pop(dominant_label)
        return dominant_labels_sorted

    def iterate_over_pc(self):
        """Greedily match predicted instances to ground-truth labels one-to-one.

        Ground-truth labels are visited from the largest to the smallest by
        point count. For each, the predicted instance sharing the most points
        with it is chosen, and both are removed from further consideration.
        When a single predicted instance is left in the overlap table, it is
        matched to its own dominant ground-truth label. Ground ('0') is never
        matched when ``remove_ground`` is set.

        Returns:
            {pred_label (int): gt_label (int)}
        """
        label_mapping_dict = {}
        dominant_labels_sorted = self.get_dominant_lables_sorted()
        gt_classes_to_iterate = self.find_dominant_classes_in_gt(self.input_las)
        for gt_class in gt_classes_to_iterate:
            if self.remove_ground:
                # only ground overlaps remain: nothing left to match
                if all(len(v) == 1 and '0' in v for v in dominant_labels_sorted.values()):
                    break
            if not any(dominant_labels_sorted.values()):
                break
            if len(dominant_labels_sorted) == 1:
                candidates = dominant_labels_sorted
                if self.remove_ground:
                    # The main path never offers ground (it is dropped from
                    # gt_classes_to_iterate), so the shortcut must not either.
                    candidates = {
                        key: {gt: n for gt, n in value.items() if gt != '0'}
                        for key, value in dominant_labels_sorted.items()
                    }
                    if not any(candidates.values()):
                        break
                dominant_label_key, dominant_label = self.get_the_dominant_label(candidates)
                label_mapping_dict[dominant_label_key] = dominant_label
                break
            extracted = self.extract_from_sub_dict(dominant_labels_sorted, gt_class)
            # no remaining instance overlaps this ground-truth label
            if not extracted or not any(extracted.values()):
                continue
            dominant_label_key, dominant_label = self.get_the_dominant_label(extracted)
            self.remove_dominant_label(dominant_labels_sorted, dominant_label_key, dominant_label)
            label_mapping_dict[dominant_label_key] = dominant_label
        return {int(k): int(v) for k, v in label_mapping_dict.items()}

    def compute_metrics(self):
        """Compute per-tree, height-weighted, mean and tree-level metrics.

        Returns:
            metric_dict: {str(pred_label): per-tree metric dict}; empty if the
                file was skipped.
            metric_dict_weighted_by_tree_hight: the interesting metrics averaged
                with ground-truth tree height as weight (all 0 when empty).
            metric_dict_mean: their unweighted mean, extended with tree-level
                detection metrics when at least one instance was matched.
        """
        metric_dict = {}
        if not self.skip_flag:
            label_mapping_dict = self.iterate_over_pc()
            if self.verbose:
                print('Computing metrics for individual trees...')
            for label in tqdm(list(label_mapping_dict.keys())):
                metric_dict[str(label)] = self._compute_single_tree_metrics(
                    label, label_mapping_dict[label])
        # metrics that are aggregated over trees
        interesting_parameters = ['precision', 'recall', 'f1_score', 'IoU', 'residual_hight(gt_minus_pred)']
        metric_dict_weighted_by_tree_hight = self._weight_metrics_by_tree_height(
            metric_dict, interesting_parameters)
        metric_dict_mean = self._mean_metrics(metric_dict, interesting_parameters)
        if metric_dict:
            metric_dict_mean.update(self._compute_tree_level_metrics(metric_dict))
        return metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean

    def _compute_single_tree_metrics(self, label, dominant_label):
        """Point-level metrics for predicted instance ``label`` matched to GT label ``dominant_label``."""
        instance_mask = self.Y_labels == label
        # ground-truth labels transferred onto this instance's predicted points
        transferred = self.dict_Y['ind_labels_Y'][instance_mask]
        # predicted points whose nearest GT point belongs to the matched label
        true_positive = np.count_nonzero(transferred == dominant_label)
        # predicted points of this instance that belong to another GT label
        false_positive = transferred.shape[0] - true_positive
        # GT points of the matched label that are no predicted point's nearest neighbour.
        # NOTE(review): predicted points of *other* instances whose transferred GT
        # label is dominant_label are not counted as false negatives — confirm intended.
        false_negative = np.count_nonzero(
            self.X_labels[self.dict_Y['residual_ind']] == dominant_label)
        true_negative = self.dict_Y['X'].shape[0] - false_negative - true_positive - false_positive
        sum_all = true_positive + false_positive + false_negative + true_negative
        # true_positive > 0 because the match was chosen from a non-zero overlap
        precision = true_positive / (true_positive + false_positive)
        recall = true_positive / (true_positive + false_negative)
        f1_score = 2 * (precision * recall) / (precision + recall)
        IoU = true_positive / (true_positive + false_positive + false_negative)
        # vertical extent of the tree in ground truth and in the prediction
        gt_z = self.input_las[self.X_labels == dominant_label].z
        pred_z = self.instance_segmented_las[instance_mask].z
        hight_of_tree_gt = gt_z.max() - gt_z.min()
        hight_of_tree_pred = pred_z.max() - pred_z.min()
        residual_hight_of_tree_pred = abs(hight_of_tree_gt - hight_of_tree_pred)
        return {
            'pred_label': label,
            'gt_label(dominant_label)': dominant_label,
            'high_of_tree_gt': hight_of_tree_gt,
            'high_of_tree_pred': hight_of_tree_pred,
            'residual_hight(gt_minus_pred)': residual_hight_of_tree_pred,
            'sum_all': sum_all,
            'true_positive': true_positive,
            'false_positive': false_positive,
            'false_negative': false_negative,
            'true_negative': true_negative,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'IoU': IoU,
        }

    @staticmethod
    def _weight_metrics_by_tree_height(metric_dict, parameters):
        """Average ``parameters`` over trees, weighted by ground-truth tree height (0 when empty)."""
        weighted = {parameter: 0 for parameter in parameters}
        if metric_dict:
            for tree in metric_dict.values():
                for parameter in parameters:
                    weighted[parameter] += tree['high_of_tree_gt'] * tree[parameter]
            # NOTE(review): a total height of 0 (only flat trees) is not guarded.
            total_height = sum(tree['high_of_tree_gt'] for tree in metric_dict.values())
            for parameter in parameters:
                weighted[parameter] /= total_height
        return weighted

    @staticmethod
    def _mean_metrics(metric_dict, parameters):
        """Unweighted mean of ``parameters`` over trees (0 when empty)."""
        mean = {parameter: 0 for parameter in parameters}
        if metric_dict:
            for tree in metric_dict.values():
                for parameter in parameters:
                    mean[parameter] += tree[parameter]
            for parameter in parameters:
                mean[parameter] = mean[parameter] / len(metric_dict)
        return mean

    def _compute_tree_level_metrics(self, metric_dict):
        """Detection, commission and omission rates over ground-truth trees (label 0 excluded).

        A ground-truth tree counts as detected when its matched instance has
        IoU > 0.5. Rates are fractions of the number of ground-truth trees and
        are 0.0 when the ground truth contains no tree.
        """
        gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
        gt_trees = gt_trees[gt_trees != 0]
        trees_predicted = np.unique(
            [tree['gt_label(dominant_label)'] for tree in metric_dict.values()])
        trees_correctly_predicted_IoU = np.unique(
            [tree['gt_label(dominant_label)'] for tree in metric_dict.values() if tree['IoU'] > 0.5])
        gt_trees = set(gt_trees)
        trees_predicted = set(trees_predicted)
        trees_correctly_predicted_IoU = set(trees_correctly_predicted_IoU)
        n_gt = len(gt_trees)

        def rate(count):
            # avoid ZeroDivisionError when the ground truth holds only label 0
            return count / n_gt if n_gt else 0.0

        tree_level_metric = {
            'true_positve (detection rate)': rate(len(trees_correctly_predicted_IoU)),
            'false_positve (commission)': rate(len(trees_predicted - trees_correctly_predicted_IoU)),
            'false_negative (omissions)': rate(len(gt_trees - trees_predicted - trees_correctly_predicted_IoU)),
            'gt': n_gt}
        if self.verbose:
            print('Tree level metrics:')
            print(f'Trees in the ground truth: {gt_trees}')
            print(f'Trees correctly predicted: {trees_predicted}')
            print(f'Trees correctly predicted with IoU > 0.5: {trees_correctly_predicted_IoU}')
            print(tree_level_metric)
        return tree_level_metric

    def print_metrics(self, metric_dict):
        """Print every tree's metrics, one ``name: value`` line each, blank line between trees."""
        for key, value in metric_dict.items():
            print(f'Label: {key}')
            for key2, value2 in value.items():
                print(f'{key2}: {value2}')
            print('')

    def save_to_csv_file(self, metric_dict):
        """Write per-tree metrics to ``csv_file_name``: one row per predicted label, index kept."""
        df = pd.DataFrame(metric_dict).T
        df.to_csv(self.csv_file_name, index=True, header=True)

    def main(self):
        """Compute metrics, optionally print them and write them to CSV.

        Returns:
            The same triple as ``compute_metrics``.
        """
        metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean = self.compute_metrics()
        if self.verbose:
            f1_weighted_by_tree_hight = metric_dict_weighted_by_tree_hight['f1_score']
            print(f'f1_score_weighted: {f1_weighted_by_tree_hight}')
            self.print_metrics(metric_dict)
        if self.csv_file_name is not None:
            self.save_to_csv_file(metric_dict)
            if self.verbose:
                print(f'Metrics saved to {self.csv_file_name}')
        return metric_dict, metric_dict_weighted_by_tree_hight, metric_dict_mean
# Command-line entry point: evaluate one ground-truth / prediction file pair.
if __name__ == '__main__':
    # CLI options: both LAS paths, the ground-removal switch, an optional
    # CSV output path and verbosity.
    cli = argparse.ArgumentParser()
    cli.add_argument('--input_file_path', type=str, required=True)
    cli.add_argument('--instance_segmented_file_path', type=str, required=True)
    cli.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
    cli.add_argument('--csv_file_name', type=str, help="Name of the csv file to save the metrics to", default=None)
    cli.add_argument('--verbose', action='store_true', help="Print information about the process", default=False)
    args = cli.parse_args()

    evaluator = InstanceSegmentationMetrics(
        input_file_path=args.input_file_path,
        instance_segmented_file_path=args.instance_segmented_file_path,
        remove_ground=args.remove_ground,
        csv_file_name=args.csv_file_name,
        verbose=args.verbose,
    )
    # main() computes, prints (if verbose) and saves; keep only per-tree metrics
    metric_dict = evaluator.main()[0]