diff --git a/metrics/instance_segmentation_metrics.py b/metrics/instance_segmentation_metrics.py
index 83e97a7baceaabe791168b15575835d1359ad690..34c715d2f9f8b954f526eba0c7c3609dee07fb67 100644
--- a/metrics/instance_segmentation_metrics.py
+++ b/metrics/instance_segmentation_metrics.py
@@ -9,6 +9,8 @@ from sklearn.neighbors import KDTree
 logging.basicConfig(level=logging.INFO)
 
 class InstanceSegmentationMetrics:
+    GT_LABEL_NAME = 'treeID'  #GT_LABEL_NAME = 'StemID'
+    TARGET_LABEL_NAME = 'instance_nr'
     def __init__(
         self, 
         input_file_path, 
@@ -27,13 +29,16 @@ class InstanceSegmentationMetrics:
         self.input_las = laspy.read(self.input_file_path)
         self.instance_segmented_las = laspy.read(self.instance_segmented_file_path)
 
-        self.skip_flag = self.check_if_labels_exist()
+        self.skip_flag = self.check_if_labels_exist(
+            X_label=self.GT_LABEL_NAME,
+            Y_label=self.TARGET_LABEL_NAME
+            )
 
         if not self.skip_flag:
             # get labels from input las file
-            self.X_labels = self.input_las.treeID.astype(int) #TODO: generalize this to other labels
+            self.X_labels = self.input_las[self.GT_LABEL_NAME].astype(int) 
             # get labels from instance segmented las file
-            self.Y_labels = self.instance_segmented_las.instance_nr.astype(int) #TODO: generalize this to other labels
+            self.Y_labels = self.instance_segmented_las[self.TARGET_LABEL_NAME].astype(int) 
             # if self.remove_ground:
             #     # the labeling starts from 0, so we need to remove the ground
             #     self.Y_labels += 1
@@ -126,12 +131,12 @@ class InstanceSegmentationMetrics:
     # define a function that finds class in input_file with the most points
     def find_dominant_classes_in_gt(self, input_file):
         # get the unique labels
-        unique_labels = np.unique(input_file.treeID).astype(int)
+        unique_labels = np.unique(input_file[self.GT_LABEL_NAME]).astype(int)
         # create a dictionary
         tmp_dict = {}
         for label in unique_labels:
             # get the indices of input_file.treeID == label
-            ind_label = np.where(input_file.treeID == label)[0]
+            ind_label = np.where(input_file[self.GT_LABEL_NAME] == label)[0]
             # put the number of points to the tmp_dict
             tmp_dict[str(label)] = ind_label.shape[0]
         # sort tmp_dict by the number of points
@@ -331,7 +336,7 @@ class InstanceSegmentationMetrics:
         # compute tree level metrics
         if metric_dict:
             # get the number of trees in the ground truth
-            gt_trees = np.unique(self.input_las.treeID)
+            gt_trees = np.unique(self.input_las[self.GT_LABEL_NAME])
 
             # remove 0 from gt_trees
             gt_trees = gt_trees[gt_trees != 0]
diff --git a/metrics/instance_segmentation_metrics_in_folder.py b/metrics/instance_segmentation_metrics_in_folder.py
index 8ddd37be3a61472450ce7d5a872675c13457d7e9..ccc57c93702e3996b1c5281e5c8ae1d85d458ab8 100644
--- a/metrics/instance_segmentation_metrics_in_folder.py
+++ b/metrics/instance_segmentation_metrics_in_folder.py
@@ -9,6 +9,9 @@ from metrics.instance_segmentation_metrics import InstanceSegmentationMetrics
 from nibio_postprocessing.attach_labels_to_las_file import AttachLabelsToLasFile
 
 class InstanceSegmentationMetricsInFolder():
+    GT_LABEL_NAME = 'treeID'
+    TARGET_LABEL_NAME = 'instance_nr'
+
     def __init__(
         self,
         gt_las_folder_path,
@@ -161,8 +164,8 @@ class InstanceSegmentationMetricsInFolder():
                     gt_las_file_path,
                     target_las_file_path,
                     update_las_file_path = os.path.join(self.output_folder_path, gt_las_file_core_name + '.las'),
-                    gt_label_name='treeID',
-                    target_label_name='treeID',
+                    gt_label_name=self.GT_LABEL_NAME,
+                    target_label_name=self.GT_LABEL_NAME,
                     verbose=self.verbose
                 ).main()