From dfda59666ad7b8def5b71fa77193c276ebf6fb6a Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Wed, 12 Oct 2022 12:21:46 +0200
Subject: [PATCH] starting instance segmentation metrics implementation

---
 metrics/__init__.py                      |  0
 metrics/instance_segmentation_metrics.py | 52 ++++++++++++++++++++++++
 2 files changed, 52 insertions(+)
 create mode 100644 metrics/__init__.py
 create mode 100644 metrics/instance_segmentation_metrics.py

diff --git a/metrics/__init__.py b/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
new file mode 100644
index 0000000..9b5e1af
--- /dev/null
+++ b/metrics/instance_segmentation_metrics.py
@@ -0,0 +1,52 @@
+
+import glob
+import argparse
+import os
+import laspy
+import logging
+
+import numpy as np
+
+
+logging.basicConfig(level=logging.INFO)
+
+class InstanceSegmentationMetrics():
+    def __init__(self, gt_folder, pred_folder):
+        self.gt_folder = gt_folder
+        self.pred_folder = pred_folder
+
+    def get_metrics_for_single_point_cloud(self, las_gt, las_pred):
+        # read las files
+        las_gt = laspy.read(las_gt)
+        las_pred = laspy.read(las_pred)
+
+        # get different classes from gt and pred
+        gt_classes = np.unique(las_gt.treeID)
+        pred_classes = np.unique(las_pred.instance_nr)
+
+        # print the number of classes in gt and pred
+        logging.info('Number of classes in gt: {}'.format(len(gt_classes)))
+        logging.info('Number of classes in pred: {}'.format(len(pred_classes)))
+
+    def get_metrics_for_all_point_clouds(self):
+        # get all las files in gt and pred folders using glob
+        las_gt = glob.glob(os.path.join(self.gt_folder, '*.las'))
+        las_pred = glob.glob(os.path.join(self.pred_folder, '*.las'))
+
+        # if the number of files in gt and pred are not the same, raise an exception
+        if len(las_gt) != len(las_pred):
+            raise Exception('Number of files in gt and pred folders are not the same.')
+        
+        # iterate over all files in gt and pred folders
+        for i in range(len(las_gt)):
+            self.get_metrics_for_single_point_cloud(las_gt[i], las_pred[i])
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gt_folder', type=str, required=True, help='Path to the folder containing ground truth point clouds.')
+    parser.add_argument('--pred_folder', type=str, required=True, help='Path to the folder containing predicted point clouds.')
+    args = parser.parse_args()
+
+    # create an instance of InstanceSegmentationMetrics class
+    instance_segmentation_metrics = InstanceSegmentationMetrics(args.gt_folder, args.pred_folder)
+    instance_segmentation_metrics.get_metrics_for_all_point_clouds()
-- 
GitLab