From d42f8a673ab7a647df4a482ad8b2f8137644bbd5 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Wed, 12 Oct 2022 12:29:23 +0200 Subject: [PATCH] overlap implemented --- metrics/instance_segmentation_metrics.py | 28 ++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py index 9b5e1af..aa94bf6 100644 --- a/metrics/instance_segmentation_metrics.py +++ b/metrics/instance_segmentation_metrics.py @@ -24,10 +24,38 @@ class InstanceSegmentationMetrics(): gt_classes = np.unique(las_gt.treeID) pred_classes = np.unique(las_pred.instance_nr) + # put x, y, z, for different classes in a dictionary + gt_dict = {} + pred_dict = {} + for gt_class in gt_classes: + gt_dict[gt_class] = np.vstack((las_gt.x[las_gt.treeID == gt_class], las_gt.y[las_gt.treeID == gt_class], las_gt.z[las_gt.treeID == gt_class])).T + for pred_class in pred_classes: + pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T + + # compute overlap for each class + overlap = {} + for gt_class in gt_dict: + for pred_class in pred_dict: + overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class]) + # print the number of classes in gt and pred logging.info('Number of classes in gt: {}'.format(len(gt_classes))) logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + # print first 10 classes in gt and pred + logging.info('First 10 classes in gt: {}'.format(gt_classes[:10])) + logging.info('First 10 classes in pred: {}'.format(pred_classes[:10])) + + # print overlap for first 10 classes + logging.info('Overlap for first 10 classes: {}'.format(overlap)) + + + def get_overlap(self, gt, pred): + # compute overlap between gt and pred + overlap = np.intersect1d(gt, pred).shape[0] + # overlap = np.sum(np.all(gt[:, None] == pred, axis=-1), axis=0) + return overlap + def get_metrics_for_all_point_clouds(self): # get all las files in gt and pred folders using glob las_gt = glob.glob(os.path.join(self.gt_folder, '*.las')) -- GitLab