diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index 9b5e1af7035e145a6f2982802b9007a98162bbce..aa94bf67fc528ec8bdc034a30a01e7df944223b7 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -24,10 +24,38 @@ class InstanceSegmentationMetrics():
         gt_classes = np.unique(las_gt.treeID)
         pred_classes = np.unique(las_pred.instance_nr)
 
+        # put x, y, z, for different classes in a dictionary
+        gt_dict = {}
+        pred_dict = {}
+        for gt_class in gt_classes:
+            gt_dict[gt_class] = np.vstack((las_gt.x[las_gt.treeID == gt_class], las_gt.y[las_gt.treeID == gt_class], las_gt.z[las_gt.treeID == gt_class])).T
+        for pred_class in pred_classes:
+            pred_dict[pred_class] = np.vstack((las_pred.x[las_pred.instance_nr == pred_class], las_pred.y[las_pred.instance_nr == pred_class], las_pred.z[las_pred.instance_nr == pred_class])).T
+
+        # compute overlap for each class
+        overlap = {}
+        for gt_class in gt_dict:
+            for pred_class in pred_dict:
+                overlap[(gt_class, pred_class)] = self.get_overlap(gt_dict[gt_class], pred_dict[pred_class])
+
         # print the number of classes in gt and pred
         logging.info('Number of classes in gt: {}'.format(len(gt_classes)))
         logging.info('Number of classes in pred: {}'.format(len(pred_classes)))
 
+        # print first 10 classes in gt and pred
+        logging.info('First 10 classes in gt: {}'.format(gt_classes[:10]))
+        logging.info('First 10 classes in pred: {}'.format(pred_classes[:10]))
+
+        # print overlap for first 10 classes
+        logging.info('Overlap for first 10 classes: {}'.format(overlap))
+
+
+    def get_overlap(self, gt, pred):
+        # compute overlap between gt and pred
+        overlap = np.intersect1d(gt, pred).shape[0]
+        # overlap = np.sum(np.all(gt[:, None] == pred, axis=-1), axis=0)
+        return overlap
+
     def get_metrics_for_all_point_clouds(self):
         # get all las files in gt and pred folders using glob
         las_gt = glob.glob(os.path.join(self.gt_folder, '*.las'))