Skip to content
Snippets Groups Projects
Commit 781bd59b authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated instance segmentation metrics with params

parent e797bc8d
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree ...@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
class InstanceSegmentationMetrics: class InstanceSegmentationMetrics:
GT_LABEL_NAME = 'treeID' #GT_LABEL_NAME = 'StemID'
TARGET_LABEL_NAME = 'instance_nr'
def __init__( def __init__(
self, self,
input_file_path, input_file_path,
...@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics: ...@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics:
self.input_las = laspy.read(self.input_file_path) self.input_las = laspy.read(self.input_file_path)
self.instance_segmented_las = laspy.read(self.instance_segmented_file_path) self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
self.skip_flag = self.check_if_labels_exist() self.skip_flag = self.check_if_labels_exist(
X_label=self.GT_LABEL_NAME,
Y_label=self.TARGET_LABEL_NAME
)
if not self.skip_flag: if not self.skip_flag:
# get labels from input las file # get labels from input las file
self.X_labels = self.input_las.treeID.astype(int) #TODO: generalize this to other labels self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int)
# get labels from instance segmented las file # get labels from instance segmented las file
self.Y_labels = self.instance_segmented_las.instance_nr.astype(int) #TODO: generalize this to other labels self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int)
# if self.remove_ground: # if self.remove_ground:
# # the labeling starts from 0, so we need to remove the ground # # the labeling starts from 0, so we need to remove the ground
# self.Y_labels += 1 # self.Y_labels += 1
...@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics: ...@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics:
# define a function that finds class in input_file with the most points # define a function that finds class in input_file with the most points
def find_dominant_classes_in_gt(self, input_file): def find_dominant_classes_in_gt(self, input_file):
# get the unique labels # get the unique labels
unique_labels = np.unique(input_file.treeID).astype(int) unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int)
# create a dictionary # create a dictionary
tmp_dict = {} tmp_dict = {}
for label in unique_labels: for label in unique_labels:
# get the indices of input_file.treeID == label # get the indices of input_file.treeID == label
ind_label = np.where(input_file.treeID == label)[0] ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0]
# put the number of points to the tmp_dict # put the number of points to the tmp_dict
tmp_dict[str(label)] = ind_label.shape[0] tmp_dict[str(label)] = ind_label.shape[0]
# sort tmp_dict by the number of points # sort tmp_dict by the number of points
...@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics: ...@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics:
# compute tree level metrics # compute tree level metrics
if metric_dict: if metric_dict:
# get the number of trees in the ground truth # get the number of trees in the ground truth
gt_trees = np.unique(self.input_las.treeID) gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
# remove 0 from gt_trees # remove 0 from gt_trees
gt_trees = gt_trees[gt_trees != 0] gt_trees = gt_trees[gt_trees != 0]
......
...@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics ...@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile
class InstanceSegmentationMetricsInFolder(): class InstanceSegmentationMetricsInFolder():
GT_LABEL_NAME = 'treeID'
TARGET_LABEL_NAME = 'instance_nr'
def __init__( def __init__(
self, self,
gt_las_folder_path, gt_las_folder_path,
...@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder(): ...@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder():
gt_las_file_path, gt_las_file_path,
target_las_file_path, target_las_file_path,
update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'), update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'),
gt_label_name='treeID', gt_label_name=self.GT_LABEL_NAME,
target_label_name='treeID', target_label_name=self.GT_LABEL_NAME,
verbose=self.verbose verbose=self.verbose
).main() ).main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment