diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py index 9b5e1af7035e145a6f2982802b9007a98162bbce..aa94bf67fc528ec8bdc034a30a01e7df944223b7 100644 --- a/metrics/instance_segmentation_metrics.py +++ b/metrics/instance_segmentation_metrics.py @@ -24,10 +24,38 @@ class InstanceSegmentationMetrics(): gt_classes = np.unique(las_gt.treeID) pred_classes = np.unique(las_pred.instance_nr) + # put x, y, z, for different classes in a dictionary + gt_dict = {} + pred_dict = {} + for gt_class in gt_classes: + gt_dict[gt_class] = np.vstack((las_gt.x[las_gt.treeID == gt_class], las_gt.y[las_gt.treeID == gt_class], las_gt.z[las_gt.treeID == gt_class])).T + for pred_class in pred_classes: + pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T + + # compute overlap for each class + overlap = {} + for gt_class in gt_dict: + for pred_class in pred_dict: + overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class]) + # print the number of classes in gt and pred logging.info('Number of classes in gt: {}'.format(len(gt_classes))) logging.info('Number of classes in pred: {}'.format(len(pred_classes))) + # print first 10 classes in gt and pred + logging.info('First 10 classes in gt: {}'.format(gt_classes[:10])) + logging.info('First 10 classes in pred: {}'.format(pred_classes[:10])) + + # print overlap for first 10 classes + logging.info('Overlap for first 10 classes: {}'.format(overlap)) + + + def get_overlap(self, gt, pred): + # compute overlap between gt and pred + overlap = np.intersect1d(gt, pred).shape[0] + # overlap = np.sum(np.all(gt[:, None] == pred, axis=-1), axis=0) + return overlap + def get_metrics_for_all_point_clouds(self): # get all las files in gt and pred folders using glob las_gt = glob.glob(os.path.join(self.gt_folder, '*.las'))