diff --git a/sean_sem_seg/other_parameters.py b/sean_sem_seg/other_parameters.py
index 6b77d620302df7827253f897b6a1f9d7aec148d5..b1f86c14dea9935f119c208446e0cc5b176ae030 100644
--- a/sean_sem_seg/other_parameters.py
+++ b/sean_sem_seg/other_parameters.py
@@ -15,6 +15,8 @@ other_parameters = dict(
     vegetation_class=2,  # Don't change
     cwd_class=3,  # Don't change
     stem_class=4,  # Don't change
+    branch_class=5,  # Don't change
+    low_vegetation_class=6,  # Don't change
     grid_resolution=0.5,  # Resolution of the DTM.
     vegetation_coverage_resolution=0.2,
     num_neighbours=5,
diff --git a/sean_sem_seg/post_segmentation_script.py b/sean_sem_seg/post_segmentation_script.py
index cd80b496d694bbf28b40b161e197bea357eb6680..8a09ab5db9af1718e49c05e112d3aa4b0f06b1f2 100644
--- a/sean_sem_seg/post_segmentation_script.py
+++ b/sean_sem_seg/post_segmentation_script.py
@@ -26,6 +26,10 @@ from tools import load_file, save_file, subsample_point_cloud, get_heights_above
 from scipy.interpolate import griddata
 from fsct_exceptions import DataQualityError
 
+from other_parameters import other_parameters
+
+NUMBER_OF_CLASSES = other_parameters["num_classes"]
+
 warnings.filterwarnings("ignore", category=RuntimeWarning)
 
 
@@ -47,6 +51,8 @@ class PostProcessing:
         self.vegetation_class_label = parameters["vegetation_class"]
         self.cwd_class_label = parameters["cwd_class"]
         self.stem_class_label = parameters["stem_class"]
+        self.branch_class_label = parameters["branch_class"]
+        self.low_vegetation_class_label = parameters["low_vegetation_class"]
         print("Loading segmented point cloud...")
         self.point_cloud, self.headers_of_interest = load_file(
             self.output_dir + "segmented.las", headers_of_interest=["x", "y", "z", "red", "green", "blue", "label"]
@@ -128,6 +134,8 @@ class PostProcessing:
         self.point_cloud = get_heights_above_DTM(
             self.point_cloud, self.DTM
         )  # Add a height above DTM column to the point clouds.
+
+        # terrain points
         self.terrain_points = self.point_cloud[self.point_cloud[:, self.label_index] == self.terrain_class_label]
         self.terrain_points_rejected = np.vstack(
             (
@@ -148,6 +156,8 @@ class PostProcessing:
             headers_of_interest=self.headers_of_interest,
             silent=False,
         )
+
+        # stem points
         self.stem_points = self.point_cloud[self.point_cloud[:, self.label_index] == self.stem_class_label]
         self.terrain_points = np.vstack(
             (
@@ -169,6 +179,53 @@ class PostProcessing:
             silent=False,
         )
 
+
+        if NUMBER_OF_CLASSES == 6:
+            #branches
+            self.branch_points = self.point_cloud[self.point_cloud[:, self.label_index] == self.branch_class_label]
+            self.terrain_points = np.vstack(
+                (
+                    self.terrain_points,
+                    self.branch_points[
+                        np.logical_and(
+                            self.branch_points[:, -1] >= -above_and_below_DTM_trim_dist,
+                            self.branch_points[:, -1] <= above_and_below_DTM_trim_dist,
+                        )
+                    ],
+                )
+            )
+            self.branch_points_rejected = self.branch_points[self.branch_points[:, -1] <= above_and_below_DTM_trim_dist]
+            self.branch_points = self.branch_points[self.branch_points[:, -1] > above_and_below_DTM_trim_dist]
+            save_file(
+                self.output_dir + "branch_points.las",
+                self.branch_points,
+                headers_of_interest=self.headers_of_interest,
+                silent=False,
+            )
+
+            #low vegetation
+            self.low_vegetation_points = self.point_cloud[self.point_cloud[:, self.label_index] == self.low_vegetation_class_label]
+            low_vegetation_threshold = 0.1
+            self.terrain_points = np.vstack(
+                (
+                    self.terrain_points,
+                    self.low_vegetation_points[
+                        np.logical_and(
+                            self.low_vegetation_points[:, -1] >= -low_vegetation_threshold,
+                            self.low_vegetation_points[:, -1] <= low_vegetation_threshold,
+                        )
+                    ],
+                )
+            )
+            self.low_vegetation_points_rejected = self.low_vegetation_points[self.low_vegetation_points[:, -1] <= low_vegetation_threshold]
+            self.low_vegetation_points = self.low_vegetation_points[self.low_vegetation_points[:, -1] > low_vegetation_threshold]
+            save_file(
+                self.output_dir + "low_vegetation_points.las",
+                self.low_vegetation_points,
+                headers_of_interest=self.headers_of_interest,
+                silent=False,
+            )
+            
         self.vegetation_points = self.point_cloud[self.point_cloud[:, self.label_index] == self.vegetation_class_label]
         self.terrain_points = np.vstack(
             (
@@ -224,7 +281,19 @@ class PostProcessing:
         )
 
         self.terrain_points[:, self.label_index] = self.terrain_class_label
-        self.cleaned_pc = np.vstack((self.terrain_points, self.vegetation_points, self.cwd_points, self.stem_points))
+        
+        if NUMBER_OF_CLASSES == 6:
+            self.cleaned_pc = np.vstack((
+                self.terrain_points, 
+                self.vegetation_points,
+                self.cwd_points, 
+                self.stem_points, 
+                self.branch_points, 
+                self.low_vegetation_points
+                ))
+        else:
+            self.cleaned_pc = np.vstack((self.terrain_points, self.vegetation_points, self.cwd_points, self.stem_points))
+
         save_file(
             self.output_dir + "segmented_cleaned.las", self.cleaned_pc, headers_of_interest=self.headers_of_interest
         )