Skip to content
Snippets Groups Projects
Commit 781bd59b authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated instance segmentation metrics with params

parent e797bc8d
Branches
Tags
No related merge requests found
......@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree
logging.basicConfig(level=logging.INFO)
class InstanceSegmentationMetrics:
GT_LABEL_NAME = 'treeID' #GT_LABEL_NAME = 'StemID'
TARGET_LABEL_NAME = 'instance_nr'
def __init__(
self,
input_file_path,
......@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics:
self.input_las = laspy.read(self.input_file_path)
self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
self.skip_flag = self.check_if_labels_exist()
self.skip_flag = self.check_if_labels_exist(
X_label=self.GT_LABEL_NAME,
Y_label=self.TARGET_LABEL_NAME
)
if not self.skip_flag:
# get labels from input las file
self.X_labels = self.input_las.treeID.astype(int) #TODO: generalize this to other labels
self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int)
# get labels from instance segmented las file
self.Y_labels = self.instance_segmented_las.instance_nr.astype(int) #TODO: generalize this to other labels
self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int)
# if self.remove_ground:
# # the labeling starts from 0, so we need to remove the ground
# self.Y_labels += 1
......@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics:
# define a function that finds class in input_file with the most points
def find_dominant_classes_in_gt(self, input_file):
# get the unique labels
unique_labels = np.unique(input_file.treeID).astype(int)
unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int)
# create a dictionary
tmp_dict = {}
for label in unique_labels:
# get the indices of input_file.treeID == label
ind_label = np.where(input_file.treeID == label)[0]
ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0]
# put the number of points to the tmp_dict
tmp_dict[str(label)] = ind_label.shape[0]
# sort tmp_dict by the number of points
......@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics:
# compute tree level metrics
if metric_dict:
# get the number of trees in the ground truth
gt_trees = np.unique(self.input_las.treeID)
gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
# remove 0 from gt_trees
gt_trees = gt_trees[gt_trees != 0]
......
......@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile
class InstanceSegmentationMetricsInFolder():
GT_LABEL_NAME = 'treeID'
TARGET_LABEL_NAME = 'instance_nr'
def __init__(
self,
gt_las_folder_path,
......@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder():
gt_las_file_path,
target_las_file_path,
update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'),
gt_label_name='treeID',
target_label_name='treeID',
gt_label_name=self.GT_LABEL_NAME,
target_label_name=self.GT_LABEL_NAME,
verbose=self.verbose
).main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment