diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py index aa94bf67fc528ec8bdc034a30a01e7df944223b7..80a76cfae145c25333b849283e467a594ac3c5ed 100644 --- a/metrics/instance_segmentation_metrics.py +++ b/metrics/instance_segmentation_metrics.py @@ -22,6 +22,8 @@ class InstanceSegmentationMetrics(): # get different classes from gt and pred gt_classes = np.unique(las_gt.treeID) + # remove 0 from gt_classes as it is the background class + gt_classes = gt_classes[gt_classes != 0] pred_classes = np.unique(las_pred.instance_nr) # put x, y, z, for different classes in a dictionary @@ -32,22 +34,66 @@ class InstanceSegmentationMetrics(): for pred_class in pred_classes: pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T + # get the number of points in gt and pred per class and put it in a dictionary + gt_dict_points = {} + pred_dict_points = {} + for gt_class in gt_classes: + gt_dict_points[gt_class] = gt_dict[gt_class].shape[0] + for pred_class in pred_classes: + pred_dict_points[pred_class] = pred_dict[pred_class].shape[0] + # compute overlap for each class overlap = {} for gt_class in gt_dict: for pred_class in pred_dict: overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class]) - # print the number of classes in gt and pred - logging.info('Number of classes in gt: {}'.format(len(gt_classes))) - logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + # get number of overlapping points per class + overlap_points = {} + for gt_class in gt_dict: + for pred_class in pred_dict: + overlap_points[(gt_class, pred_class)] = np.sum(overlap[(gt_class, pred_class)]) + + # sort the overlap points in descending order + overlap_points = {k: v for k, v in sorted(overlap_points.items(), key=lambda item: item[1], reverse=True)} + + # sort out overlaps by the number of points in gt and pred + sorted_overlap = sorted(overlap.items(), key=lambda x: x[1], reverse=True) + sorted_overlap_points = sorted(overlap_points.items(), key=lambda x: x[1], reverse=True) + + # # print the number of classes in gt and pred + # logging.info('Number of classes in gt: {}'.format(len(gt_classes))) + # logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + + # # print the number of points in gt and pred + # logging.info('Number of points in gt: {}'.format(sum(gt_dict_points.values()))) + # logging.info('Number of points in pred: {}'.format(sum(pred_dict_points.values()))) + + # # print the number of points in gt and pred per class + # logging.info('Number of points in gt per class: {}'.format(gt_dict_points)) + # logging.info('Number of points in pred per class: {}'.format(pred_dict_points)) + + # # print the number of overlapping points per class + # logging.info('Number of overlapping points per class: {}'.format(overlap_points)) + + # # print sorted overlap + # logging.info('Sorted overlap: {}'.format(sorted_overlap)) + + # find overlap between gt 39 and pred 6 + logging.info('Overlap between gt 39 and pred 6: {}'.format(overlap[(39, 6)])) + + # find overlap between gt 39 and pred 2 + logging.info('Overlap between gt 39 and pred 2: {}'.format(overlap[(39, 2)])) + + + # print sorted overlap for first 10 classes + # logging.info('Sorted overlap for classes: {}'.format(sorted_overlap)) + + # # print best match for classes along with overlap + # logging.info('Best match for classes: {}'.format(best_match)) + - # print first 10 classes in gt and pred - logging.info('First 10 classes in gt: {}'.format(gt_classes[:10])) - logging.info('First 10 classes in pred: {}'.format(pred_classes[:10])) - # print overlap for first 10 classes - logging.info('Overlap for first 10 classes: {}'.format(overlap)) def get_overlap(self, gt, pred):