From 3f8ebc4e0913ec5547296b1d1376e32648394b16 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Wed, 12 Oct 2022 15:49:41 +0200 Subject: [PATCH] update of metrics --- metrics/instance_segmentation_metrics.py | 62 +++++++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py index aa94bf6..80a76cf 100644 --- a/metrics/instance_segmentation_metrics.py +++ b/metrics/instance_segmentation_metrics.py @@ -22,6 +22,8 @@ class InstanceSegmentationMetrics(): # get different classes from gt and pred gt_classes = np.unique(las_gt.treeID) + # remove 0 from gt_classes as it is the background class + gt_classes = gt_classes[gt_classes != 0] pred_classes = np.unique(las_pred.instance_nr) # put x, y, z, for different classes in a dictionary @@ -32,22 +34,66 @@ class InstanceSegmentationMetrics(): for pred_class in pred_classes: pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T + # get the number of points in gt and pred per class and put it in a dictionary + gt_dict_points = {} + pred_dict_points = {} + for gt_class in gt_classes: + gt_dict_points[gt_class] = gt_dict[gt_class].shape[0] + for pred_class in pred_classes: + pred_dict_points[pred_class] = pred_dict[pred_class].shape[0] + # compute overlap for each class overlap = {} for gt_class in gt_dict: for pred_class in pred_dict: overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class]) - # print the number of classes in gt and pred - logging.info('Number of classes in gt: {}'.format(len(gt_classes))) - logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + # get number of overlapping points per class + overlap_points = {} + for gt_class in gt_dict: + for pred_class in pred_dict: + overlap_points[(gt_class, pred_class)] = np.sum(overlap[(gt_class, pred_class)]) + + # sort the overlap points in descending order + overlap_points = {k: v for k, v in sorted(overlap_points.items(), key=lambda item: item[1], reverse=True)} + + # sort out overlaps by the number of points in gt and pred + sorted_overlap = sorted(overlap.items(), key=lambda x: x[1], reverse=True) + sorted_overlap_points = sorted(overlap_points.items(), key=lambda x: x[1], reverse=True) + + # # print the number of classes in gt and pred + # logging.info('Number of classes in gt: {}'.format(len(gt_classes))) + # logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + + # # print the number of points in gt and pred + # logging.info('Number of points in gt: {}'.format(sum(gt_dict_points.values()))) + # logging.info('Number of points in pred: {}'.format(sum(pred_dict_points.values()))) + + # # print the number of points in gt and pred per class + # logging.info('Number of points in gt per class: {}'.format(gt_dict_points)) + # logging.info('Number of points in pred per class: {}'.format(pred_dict_points)) + + # # print the number of overlapping points per class + # logging.info('Number of overlapping points per class: {}'.format(overlap_points)) + + # # print sorted overlap + # logging.info('Sorted overlap: {}'.format(sorted_overlap)) + + # find overlap between gt 39 and pred 6 + logging.info('Overlap between gt 39 and pred 6: {}'.format(overlap[(39, 6)])) + + # find overlap between gt 39 and pred 2 + logging.info('Overlap between gt 39 and pred 2: {}'.format(overlap[(39, 2)])) + + + # print sorted overlap for first 10 classes + # logging.info('Sorted overlap for classes: {}'.format(sorted_overlap)) + + # # print best match for classes along with overlap + # logging.info('Best match for classes: {}'.format(best_match)) + - # print first 10 classes in gt and pred - logging.info('First 10 classes in gt: {}'.format(gt_classes[:10])) - logging.info('First 10 classes in pred: {}'.format(pred_classes[:10])) - # print overlap for first 10 classes - logging.info('Overlap for first 10 classes: {}'.format(overlap)) def get_overlap(self, gt, pred): -- GitLab