diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py index 83e97a7baceaabe791168b15575835d1359ad690..34c715d2f9f8b954f526eba0c7c3609dee07fb67 100644 --- a/metrics/instance_segmentation_metrics.py +++ b/metrics/instance_segmentation_metrics.py @@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree logging.basicConfig(level=logging.INFO) class InstanceSegmentationMetrics: + GT_LABEL_NAME = 'treeID' #GT_LABEL_NAME = 'StemID' + TARGET_LABEL_NAME = 'instance_nr' def __init__( self, input_file_path, @@ -27,13 +29,16 @@ class InstanceSegmentationMetrics: self.input_las = laspy.read(self.input_file_path) self.instance_segmented_las = laspy.read(self.instance_segmented_file_path) - self.skip_flag = self.check_if_labels_exist() + self.skip_flag = self.check_if_labels_exist( + X_label=self.GT_LABEL_NAME, + Y_label=self.TARGET_LABEL_NAME + ) if not self.skip_flag: # get labels from input las file - self.X_labels = self.input_las.treeID.astype(int) #TODO: generalize this to other labels + self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int) # get labels from instance segmented las file - self.Y_labels = self.instance_segmented_las.instance_nr.astype(int) #TODO: generalize this to other labels + self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int) # if self.remove_ground: # # the labeling starts from 0, so we need to remove the ground # self.Y_labels += 1 @@ -126,12 +131,12 @@ class InstanceSegmentationMetrics: # define a function that finds class in input_file with the most points def find_dominant_classes_in_gt(self, input_file): # get the unique labels - unique_labels = np.unique(input_file.treeID).astype(int) + unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int) # create a dictionary tmp_dict = {} for label in unique_labels: # get the indices of input_file.treeID == label - ind_label = np.where(input_file.treeID == label)[0] + ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0] # put the number of points to the tmp_dict tmp_dict[str(label)] = ind_label.shape[0] # sort tmp_dict by the number of points @@ -331,7 +336,7 @@ class InstanceSegmentationMetrics: # compute tree level metrics if metric_dict: # get the number of trees in the ground truth - gt_trees = np.unique(self.input_las.treeID) + gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME]) # remove 0 from gt_trees gt_trees = gt_trees[gt_trees != 0] diff --git a/metrics/instance_segmentation_metrics_in_folder.py b/metrics/instance_segmentation_metrics_in_folder.py index 8ddd37be3a61472450ce7d5a872675c13457d7e9..ccc57c93702e3996b1c5281e5c8ae1d85d458ab8 100644 --- a/metrics/instance_segmentation_metrics_in_folder.py +++ b/metrics/instance_segmentation_metrics_in_folder.py @@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile class InstanceSegmentationMetricsInFolder(): + GT_LABEL_NAME = 'treeID' + TARGET_LABEL_NAME = 'instance_nr' + def __init__( self, gt_las_folder_path, @@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder(): gt_las_file_path, target_las_file_path, update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'), - gt_label_name='treeID', - target_label_name='treeID', + gt_label_name=self.GT_LABEL_NAME, + target_label_name=self.GT_LABEL_NAME, verbose=self.verbose ).main()