diff --git a/.gitignore b/.gitignore
index 25dc176458dafda36b4e752d5a6b532d9ddaaa91..9e04dff31ea8dfaf5a76834450130c223cef1001 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@ wandb/
 .vscode/
 sample_data/segmented_point_clouds/*
 *.dat
-*.json
\ No newline at end of file
+*.json
+*.png
\ No newline at end of file
diff --git a/metrics/metrics_sem_seg.py b/metrics/metrics_sem_seg.py
index b93e7e8f1bb47e3270d2112b710547337d8a0acf..f623a9e077c2e52c10a1eb6b8a97ee5c37566469 100644
--- a/metrics/metrics_sem_seg.py
+++ b/metrics/metrics_sem_seg.py
@@ -1,4 +1,5 @@
 import argparse
+import json
 import os
 from joblib import Parallel, delayed
 import laspy
@@ -97,7 +98,11 @@ class MetricSemSeg:
             labels = labels[labels != 0]
             labels = labels - 1
         else:
-            labels = point_cloud.label
+            # if the labels do not start at 0 (i.e. they start at 1), shift them down by 1
+            if np.min(point_cloud.label) != 0:
+                labels = point_cloud.label - 1
+            else:
+                labels = point_cloud.label
 
         # convert to int
         labels = labels.astype(int)
@@ -176,12 +181,15 @@ class MetricSemSeg:
         if self.verbose:
             print("conf_matrix.shape: ", conf_matrix.shape)
 
+        # derive the class names from the number of classes in the confusion matrix
+        if conf_matrix.shape[0] == 3:
+            class_names = ['terrain', 'vegetation', 'stem']
+        elif conf_matrix.shape[0] == 4:
+            class_names = ['terrain', 'vegetation', 'CWD', 'stem']
+        else:
+            raise ValueError("Expected 3 or 4 classes, got {}.".format(conf_matrix.shape[0]))
+
         # get picture of the confusion matrix
         if self.plot_confusion_matrix:
-            if conf_matrix.shape[0] == 3:
-                class_names = ['terrain', 'vegetation', 'stem']
-            elif conf_matrix.shape[0] == 4:
-                class_names = ['terrain', 'vegetation', 'CWD', 'stem']
             disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_names)
             disp.plot()
             plt.savefig(file_name_original + '_confusion_matrix.png')
@@ -194,14 +202,18 @@ class MetricSemSeg:
         f1 = f1_score(labels_original_closest, labels_predicted, average='weighted')
         f1 = np.round(f1, decimals=3)
 
-        # compute precision, recall and f1 per class
-        precision_per_class = precision_score(labels_original_closest, labels_predicted, average=None)  
-        precision_per_class = np.round(precision_per_class, decimals=3)
-        recall_per_class = recall_score(labels_original_closest, labels_predicted, average=None)
-        recall_per_class = np.round(recall_per_class, decimals=3)
-        f1_per_class = f1_score(labels_original_closest, labels_predicted, average=None)
-        f1_per_class = np.round(f1_per_class, decimals=3)
-
+        # compute precision, recall and f1 for each class, keyed by class name
+        precision_per_class = {}
+        recall_per_class = {}
+        f1_per_class = {}
+        for name in class_names:
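+            # note: restricting `labels` to a single class makes the 'weighted'
+            # average reduce to that class's own precision / recall / f1 score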
+            precision_per_class[name] = precision_score(labels_original_closest, labels_predicted, labels=[class_names.index(name)], average='weighted')
+            precision_per_class[name] = np.round(precision_per_class[name], decimals=3)
+            recall_per_class[name] = recall_score(labels_original_closest, labels_predicted, labels=[class_names.index(name)], average='weighted')
+            recall_per_class[name] = np.round(recall_per_class[name], decimals=3)
+            f1_per_class[name] = f1_score(labels_original_closest, labels_predicted, labels=[class_names.index(name)], average='weighted')
+            f1_per_class[name] = np.round(f1_per_class[name], decimals=3)
+
         # put all the results in a dictionary
         results = {
             'confusion_matrix': conf_matrix,
@@ -257,12 +269,17 @@ class MetricSemSeg:
         # compute the mean of the confusion matrix
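+        # note: this is an element-wise mean, which assumes every per-file
+        # confusion matrix has the same shape (same number of classes)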
         conf_matrix_mean = np.mean(conf_matrix_list, axis=0)
 
+        # round to two decimal places
+        conf_matrix_mean = np.round(conf_matrix_mean, decimals=2)
+
+        # derive the class names from the number of classes in the mean confusion matrix
+        if conf_matrix_mean.shape[0] == 3:
+            class_names = ['terrain', 'vegetation', 'stem']
+        elif conf_matrix_mean.shape[0] == 4:
+            class_names = ['terrain', 'vegetation', 'CWD', 'stem']
+        else:
+            raise ValueError("Expected 3 or 4 classes, got {}.".format(conf_matrix_mean.shape[0]))
+
         # save the confusion matrix
         if self.plot_confusion_matrix:
-            if conf_matrix_mean.shape[0] == 3:
-                class_names = ['terrain', 'vegetation', 'stem']
-            elif conf_matrix_mean.shape[0] == 4:
-                class_names = ['terrain', 'vegetation', 'CWD', 'stem']
             disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix_mean, display_labels=class_names)
             disp.plot()
             plt.savefig('confusion_matrix_mean.png')
@@ -272,47 +289,128 @@ class MetricSemSeg:
         recall_mean = np.mean(recall_list)
         f1_mean = np.mean(f1_list)
 
+        # round to two decimal places
+        precision_mean = np.round(precision_mean, decimals=2)
+        recall_mean = np.round(recall_mean, decimals=2)
+        f1_mean = np.round(f1_mean, decimals=2)
+
         # compute the mean of the precision, recall and f1 score per class
         # compute separately for items which have 3 and 4 classes
-        # create a separate list for item with 3 classes
-        precision_per_class_list_3 = []
-        recall_per_class_list_3 = []
-        f1_per_class_list_3 = []
+        # create separate per-class score lists (keyed by class name) for items with 3 classes
+        precision_per_class_list_3 = {}
+        recall_per_class_list_3 = {}
+        f1_per_class_list_3 = {}
+        for name in class_names:
+            precision_per_class_list_3[name] = []
+            recall_per_class_list_3[name] = []
+            f1_per_class_list_3[name] = []
+
         # create a separate list for item with 4 classes
-        precision_per_class_list_4 = []
-        recall_per_class_list_4 = []
-        f1_per_class_list_4 = []
+        precision_per_class_list_4 = {}
+        recall_per_class_list_4 = {}
+        f1_per_class_list_4 = {}
+        for name in class_names:
+            precision_per_class_list_4[name] = []
+            recall_per_class_list_4[name] = []
+            f1_per_class_list_4[name] = []
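+        # note: these lists stay empty when no file has that number of classes,
+        # which makes the corresponding per-class mean below NaN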
+
         # loop over all items
-        for i in range(len(precision_per_class_list)):
-            # check if the item has 3 classes
-            if len(precision_per_class_list[i]) == 3:
-                precision_per_class_list_3.append(precision_per_class_list[i])
-                recall_per_class_list_3.append(recall_per_class_list[i])
-                f1_per_class_list_3.append(f1_per_class_list[i])
-            # check if the item has 4 classes
-            elif len(precision_per_class_list[i]) == 4:
-                precision_per_class_list_4.append(precision_per_class_list[i])
-                recall_per_class_list_4.append(recall_per_class_list[i])
-                f1_per_class_list_4.append(f1_per_class_list[i])
+        for precision_per_class, recall_per_class, f1_per_class in zip(precision_per_class_list, recall_per_class_list, f1_per_class_list):
+            # check if the item has 3 or 4 classes
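+            # (items without the CWD class only have 3 classes and are averaged separately)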
+            if len(precision_per_class) == 3:
+                for name in class_names:
+                    precision_per_class_list_3[name].append(precision_per_class[name])
+                    recall_per_class_list_3[name].append(recall_per_class[name])
+                    f1_per_class_list_3[name].append(f1_per_class[name])
+            elif len(precision_per_class) == 4:
+                for name in class_names:
+                    precision_per_class_list_4[name].append(precision_per_class[name])
+                    recall_per_class_list_4[name].append(recall_per_class[name])
+                    f1_per_class_list_4[name].append(f1_per_class[name])
             else:
-                print("ERROR: the number of classes is neither 3 nor 4")
-                exit()
-
+                raise ValueError("The number of classes is neither 3 nor 4.")
 
         # compute the mean of the precision, recall and f1 score per class for items with 3 classes
-        precision_per_class_mean_3 = np.mean(precision_per_class_list_3, axis=0)
-        recall_per_class_mean_3 = np.mean(recall_per_class_list_3, axis=0)
-        f1_per_class_mean_3 = np.mean(f1_per_class_list_3, axis=0)
+        precision_per_class_list_mean_3 = {} 
+        recall_per_class_list_mean_3 = {}
+        f1_per_class_list_mean_3 = {}
+        for name in class_names:
+            precision_per_class_list_mean_3[name] = np.mean(precision_per_class_list_3[name])
+            recall_per_class_list_mean_3[name] = np.mean(recall_per_class_list_3[name])
+            f1_per_class_list_mean_3[name] = np.mean(f1_per_class_list_3[name])
 
         # compute the mean of the precision, recall and f1 score per class for items with 4 classes
-        precision_per_class_mean_4 = np.mean(precision_per_class_list_4, axis=0)
-        recall_per_class_mean_4 = np.mean(recall_per_class_list_4, axis=0)
-        f1_per_class_mean_4 = np.mean(f1_per_class_list_4, axis=0)
+        precision_per_class_list_mean_4 = {}
+        recall_per_class_list_mean_4 = {}
+        f1_per_class_list_mean_4 = {}
+        for name in class_names:
+            precision_per_class_list_mean_4[name] = np.mean(precision_per_class_list_4[name])
+            recall_per_class_list_mean_4[name] = np.mean(recall_per_class_list_4[name])
+            f1_per_class_list_mean_4[name] = np.mean(f1_per_class_list_4[name])
 
         # compute the mean of the precision, recall and f1 score per class
-        precision_per_class_mean = np.mean(np.concatenate((precision_per_class_mean_3, precision_per_class_mean_4), axis=0), axis=0)
-        recall_per_class_mean = np.mean(np.concatenate((recall_per_class_mean_3, recall_per_class_mean_4), axis=0), axis=0)
-        f1_per_class_mean = np.mean(np.concatenate((f1_per_class_mean_3, f1_per_class_mean_4), axis=0), axis=0)
+        precision_per_class_mean = {}
+        recall_per_class_mean = {}
+        f1_per_class_mean = {}
+
+        # replace any NaN mean values with 0 (np.mean of an empty list is NaN)
+        for name in class_names:
+            if np.isnan(precision_per_class_list_mean_3[name]):
+                precision_per_class_list_mean_3[name] = 0
+            if np.isnan(recall_per_class_list_mean_3[name]):
+                recall_per_class_list_mean_3[name] = 0
+            if np.isnan(f1_per_class_list_mean_3[name]):
+                f1_per_class_list_mean_3[name] = 0
+            if np.isnan(precision_per_class_list_mean_4[name]):
+                precision_per_class_list_mean_4[name] = 0
+            if np.isnan(recall_per_class_list_mean_4[name]):
+                recall_per_class_list_mean_4[name] = 0
+            if np.isnan(f1_per_class_list_mean_4[name]):
+                f1_per_class_list_mean_4[name] = 0
+
+        for name in class_names:
+            for mean_3, mean_4, mean_all in (
+                (precision_per_class_list_mean_3, precision_per_class_list_mean_4, precision_per_class_mean),
+                (recall_per_class_list_mean_3, recall_per_class_list_mean_4, recall_per_class_mean),
+                (f1_per_class_list_mean_3, f1_per_class_list_mean_4, f1_per_class_mean),
+            ):
+                # if both the 3-class and the 4-class mean are non-zero, average them
+                if mean_3[name] != 0 and mean_4[name] != 0:
+                    mean_all[name] = (mean_3[name] + mean_4[name]) / 2
+                # if one of them is 0 (no items with that class count), take the other
+                # one; this also covers the case where both are 0
+                elif mean_3[name] == 0:
+                    mean_all[name] = mean_4[name]
+                else:
+                    mean_all[name] = mean_3[name]
+
+        # round to two decimal places
+        for name in class_names:
+            precision_per_class_mean[name] = round(precision_per_class_mean[name], 2)
+            recall_per_class_mean[name] = round(recall_per_class_mean[name], 2)
+            f1_per_class_mean[name] = round(f1_per_class_mean[name], 2)
 
         # put all the results in a dictionary
         results = {
@@ -325,17 +423,22 @@ class MetricSemSeg:
             'f1_per_class': f1_per_class_mean
         }
 
+        if self.verbose:
+            # save results to a json file in self.gt_folder
+            with open(os.path.join(self.gt_folder, 'results.json'), 'w') as f:
+                # convert the numpy confusion matrix to a nested list so it is JSON serializable
+                results['confusion_matrix'] = results['confusion_matrix'].tolist()
+                json.dump(results, f, indent=4)
+
+            print("The results are saved to the file: ", os.path.join(self.gt_folder, 'results.json'))
+
         return results
       
     def main(self):
-        # get the metrics for all point clouds
-        if self.verbose:
-            print("get the metrics for all point clouds")
-
+        # get the metrics for all files
         results = self.get_metrics_for_all_files()
-
-        if self.verbose:
-            print("results: ", results)
+        print("results: ", results)
 
         return results