Skip to content
Snippets Groups Projects
Commit c0c14045 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

bugs fixed in instance segmentation metrics

parent 52264d67
No related branches found
No related tags found
No related merge requests found
......@@ -9,8 +9,8 @@ mv $TARGET_FOLDER/small_file_pipeline_test.las $TARGET_FOLDER/first.las
# make a copy of the file
cp $TARGET_FOLDER/first.las $TARGET_FOLDER/second.las
# # make a copy of the file
# cp $TARGET_FOLDER/first.las $TARGET_FOLDER/third.las
# make a copy of the file
cp $TARGET_FOLDER/first.las $TARGET_FOLDER/third.las
# # make a copy of the file
# cp $TARGET_FOLDER/first.las $TARGET_FOLDER/fourth.las
\ No newline at end of file
# make a copy of the file
cp $TARGET_FOLDER/first.las $TARGET_FOLDER/fourth.las
\ No newline at end of file
......@@ -15,12 +15,14 @@ class InstanceSegmentationMetrics:
self,
input_file_path,
instance_segmented_file_path,
remove_ground = False,
save_to_csv=False,
verbose=False
) -> None:
self.input_file_path = input_file_path
self.instance_segmented_file_path = instance_segmented_file_path
self.remove_ground = remove_ground
self.save_to_csv = save_to_csv
self.verbose = verbose
# read and prepare input las file and instance segmented las file
......@@ -30,6 +32,11 @@ class InstanceSegmentationMetrics:
self.X_labels = self.input_las.treeID.astype(int) #TODO: generalize this to other labels
# get labels from instance segmented las file
self.Y_labels = self.instance_segmented_las.instance_nr.astype(int) #TODO: generalize this to other labels
if self.remove_ground:
# the labeling starts from 0, so we need to remove the ground
self.Y_labels += 1
# do knn mapping
self.dict_Y = self.do_knn_mapping()
def do_knn_mapping(self):
......@@ -113,11 +120,15 @@ class InstanceSegmentationMetrics:
tmp_dict[str(label)] = ind_label.shape[0]
# sort tmp_dict by the number of points
tmp_dict_sorted = {k: v for k, v in sorted(tmp_dict.items(), key=lambda item: item[1], reverse=True)}
# remove key 0 from tmp_dict_sorted
if self.remove_ground:
tmp_dict_sorted.pop('0', None)
return tmp_dict_sorted.keys()
def get_the_dominant_label(self, dominant_labels_sorted):
# get the dominant label
# iterate over the dominant_labels_sorted and sort it based on the first value of sub-dictionary
# if sub-dictionary is empty, remove the key from the dictionary
......@@ -159,16 +170,36 @@ class InstanceSegmentationMetrics:
label_mapping_dict = {}
dominant_labels_sorted = self.get_dominant_lables_sorted()
gt_classes_to_iterate = self. find_dominant_classes_in_gt(self.input_las)
gt_classes_to_iterate = self.find_dominant_classes_in_gt(self.input_las)
for gt_class in gt_classes_to_iterate:
# if all the values in dominant_labels_sorted are empty, break the loop
if self.remove_ground:
# check if all the sub-dictionaries have only one key and it is 0
if all(len(v) == 1 and '0' in v for v in dominant_labels_sorted.values()):
break
if not any(dominant_labels_sorted.values()):
break
if len(dominant_labels_sorted) == 1:
dominant_label_key, dominant_label = self.get_the_dominant_label(dominant_labels_sorted)
label_mapping_dict[dominant_label_key] = dominant_label
break
dominant_label_key, dominant_label = self.get_the_dominant_label(
self.extract_from_sub_dict(dominant_labels_sorted, gt_class))
extracted = self.extract_from_sub_dict(dominant_labels_sorted, gt_class)
# if extracted is empty, continue
if not extracted:
continue
# if all the values in extracted are empty, continue
if not any(extracted.values()):
continue
dominant_label_key, dominant_label = self.get_the_dominant_label(extracted)
self.remove_dominant_label(dominant_labels_sorted, dominant_label_key, dominant_label)
......@@ -182,11 +213,9 @@ class InstanceSegmentationMetrics:
def compute_metrics(self):
# get the label_mapping_dict
label_mapping_dict = self.iterate_over_pc()
Y_unique_labels = np.unique(self.Y_labels)
# map the labels
metric_dict = {}
for label in Y_unique_labels:
for label in list(label_mapping_dict.keys()):
# get the indices of Y_labels == label
ind_Y_labels_label = np.where(self.Y_labels == label)[0]
......@@ -247,9 +276,8 @@ class InstanceSegmentationMetrics:
for key, value in metric_dict.items():
f1_score_weighted += value['f1_score'] * value['high_of_tree']
f1_score_weighted = f1_score_weighted / self.input_las.z.max()
f1_score_weighted = f1_score_weighted / len(Y_unique_labels)
# compute f1_score_weighted by dividing by the sum of the hights of the trees
f1_score_weighted = f1_score_weighted / sum([value['high_of_tree'] for key, value in metric_dict.items()])
return metric_dict, f1_score_weighted
......@@ -286,6 +314,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input_file_path', type=str, required=True)
parser.add_argument('--instance_segmented_file_path', type=str, required=True)
parser.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
parser.add_argument('--save_to_csv', action='store_true', help="Save the metrics to a csv file", default=False)
parser.add_argument('--verbose', action='store_true', help="Print information about the process", default=False)
......@@ -295,6 +324,7 @@ if __name__ == '__main__':
instance_segmentation_metrics = InstanceSegmentationMetrics(
args.input_file_path,
args.instance_segmented_file_path,
args.remove_ground,
args.save_to_csv,
args.verbose
)
......
import glob
import os
import laspy
from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
class InstanceSegmentationMetricsInFolder():
......@@ -7,10 +8,12 @@ class InstanceSegmentationMetricsInFolder():
self,
gt_las_folder_path,
target_las_folder_path,
remove_ground=False,
verbose=False
):
self.gt_las_folder_path = gt_las_folder_path
self.target_las_folder_path = target_las_folder_path
self.remove_ground = remove_ground
self.verbose = verbose
def main(self):
......@@ -29,6 +32,30 @@ class InstanceSegmentationMetricsInFolder():
# match the core name in the gt_las_file_path and target_las_file_path and make tuples of the matched paths
matched_paths = []
for gt_las_file_path, target_las_file_path in zip(gt_las_file_paths, target_las_file_paths):
# read the las file check if las file is not empty
gt_las_file = laspy.read(gt_las_file_path)
if len(gt_las_file.points) == 0:
# remove the las file and the corresponding target_las_file_path
os.remove(gt_las_file_path)
os.remove(target_las_file_path)
if self.verbose:
print('Removed empty las file: ' + gt_las_file_path)
# check if las file is not empty
target_las_file = laspy.read(target_las_file_path)
if len(target_las_file.points) == 0:
# remove the las file and the corresponding target_las_file_path
os.remove(gt_las_file_path)
os.remove(target_las_file_path)
if self.verbose:
print('Removed empty las file: ' + target_las_file_path)
# print what files are being matched and processed
if self.verbose:
print('Matching: ' + gt_las_file_path + ' and ' + target_las_file_path)
# get the core name of the gt_las_file_path
gt_las_file_core_name = os.path.basename(gt_las_file_path).split('.')[0]
# get the core name of the target_las_file_path
......@@ -54,10 +81,14 @@ class InstanceSegmentationMetricsInFolder():
# check that the core name of the gt_las_file_path and target_las_file_path are the same
if gt_las_file_core_name == target_las_file_core_name:
if self.verbose:
print('Processing: ' + gt_las_file_path + ' and ' + target_las_file_path)
# run the instance segmentation metrics
instance_segmentation_metrics = InstanceSegmentationMetrics(
gt_las_file_path,
target_las_file_path,
remove_ground=self.remove_ground,
verbose=self.verbose
)
_, f1_score_weighted = instance_segmentation_metrics.main()
......@@ -76,6 +107,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gt_las_folder_path', type=str, required=True)
parser.add_argument('--target_las_folder_path', type=str, required=True)
parser.add_argument('--remove_ground', action='store_true', help="Do not take into account the ground (class 0).", default=False)
parser.add_argument('--verbose', action='store_true', help="Print information about the process")
args = parser.parse_args()
......@@ -83,6 +115,7 @@ if __name__ == '__main__':
instance_segmentation_metrics_in_folder = InstanceSegmentationMetricsInFolder(
args.gt_las_folder_path,
args.target_las_folder_path,
args.remove_ground,
verbose=args.verbose
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment