Skip to content
Snippets Groups Projects
Commit d42f8a67 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

overlap implemented

parent dfda5966
No related branches found
No related tags found
No related merge requests found
...@@ -24,10 +24,38 @@ class InstanceSegmentationMetrics(): ...@@ -24,10 +24,38 @@ class InstanceSegmentationMetrics():
gt_classes = np.unique(las_gt.treeID) gt_classes = np.unique(las_gt.treeID)
pred_classes = np.unique(las_pred.instance_nr) pred_classes = np.unique(las_pred.instance_nr)
# put x, y, z, for different classes in a dictionary
gt_dict = {}
pred_dict = {}
for gt_class in gt_classes:
gt_dict[gt_class] = np.vstack((las_gt.x[las_gt.treeID == gt_class], las_gt.y[las_gt.treeID == gt_class], las_gt.z[las_gt.treeID == gt_class])).T
for pred_class in pred_classes:
pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T
# compute overlap for each class
overlap = {}
for gt_class in gt_dict:
for pred_class in pred_dict:
overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class])
# print the number of classes in gt and pred # print the number of classes in gt and pred
logging.info('Number of classes in gt: {}'.format(len(gt_classes))) logging.info('Number of classes in gt: {}'.format(len(gt_classes)))
logging.info('Number of classes in pred: {}'.format(len(pred_classes))) logging.info('Number of classes in pred: {}'.format(len(pred_classes)))
# print first 10 classes in gt and pred
logging.info('First 10 classes in gt: {}'.format(gt_classes[:10]))
logging.info('First 10 classes in pred: {}'.format(pred_classes[:10]))
# print overlap for first 10 classes
logging.info('Overlap for first 10 classes: {}'.format(overlap))
def get_overlap(self, gt, pred):
# compute overlap between gt and pred
overlap = np.intersect1d(gt, pred).shape[0]
# overlap = np.sum(np.all(gt[:, None] == pred, axis=-1), axis=0)
return overlap
def get_metrics_for_all_point_clouds(self): def get_metrics_for_all_point_clouds(self):
# get all las files in gt and pred folders using glob # get all las files in gt and pred folders using glob
las_gt = glob.glob(os.path.join(self.gt_folder, '*.las')) las_gt = glob.glob(os.path.join(self.gt_folder, '*.las'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment