From 53cdb4ba950976281a6b58831f656467660de428 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Sun, 12 Feb 2023 11:23:37 +0100
Subject: [PATCH] In the process of updating Sean's model to work with 4 or
 6 classes

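The number of classes is no longer hard-coded to 4 in inference.py: it is
read from other_parameters['num_classes'], which now sizes both the Net
model and the label columns built in choose_most_confident_label. Switching
to a 6-class model should therefore only require changing that one entry
and pointing the bash script at a matching checkpoint. A minimal sketch of
the intended 6-class configuration (the checkpoint filename below is
hypothetical; it must be a model actually trained with 6 classes):

    # sean_sem_seg/other_parameters.py
    num_classes=6,  # must match the class count of the loaded checkpoint

    # run_bash_scripts/sem_seg_sean.sh (hypothetical 6-class checkpoint)
    checkpoint_model_path="./fsct/model/model_6_classes.pth"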
---
 run_bash_scripts/sem_seg_sean.sh         | 11 ++++++++---
 sean_sem_seg/inference.py                | 17 ++++++++++++-----
 sean_sem_seg/other_parameters.py         |  1 +
 sean_sem_seg/post_segmentation_script.py |  2 ++
 4 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/run_bash_scripts/sem_seg_sean.sh b/run_bash_scripts/sem_seg_sean.sh
index 37327d2..33ac276 100755
--- a/run_bash_scripts/sem_seg_sean.sh
+++ b/run_bash_scripts/sem_seg_sean.sh
@@ -3,12 +3,12 @@
 ############################ parameters #################################################
 # General parameters
 CLEAR_INPUT_FOLDER=1  # 1: clear the input folder, 0: do not clear it
-CONDA_ENV="pdal-env" # conda environment for running the pipeline
+CONDA_ENV="pdal-env-1" # conda environment for running the pipeline
 
 # Parameters for the semantic segmentation
 data_folder="" # path to the folder containing the data
-checkpoint_model_path="./fsct/model/model.pth"
-batch_size=5 # batch size for the inference
+checkpoint_model_path="./fsct/model/model.pth" # path to the model checkpoint (our basic model is used by default)
+batch_size=10 # batch size for the inference
 tile_size=10 # tile size in meters
 min_density=75 # minimum density of points in a tile (used for removing small tiles)
 remove_small_tiles=0 # 1: remove small tiles, 0: keep small tiles
@@ -45,6 +45,11 @@ echo "d: data_folder"
 echo "      The values of the parameters:"
 echo "data_folder: $data_folder"
 echo "remove_small_tiles: $remove_small_tiles"
+echo "checkpoint_model_path: $checkpoint_model_path"
+echo "batch_size: $batch_size"
+echo "tile_size: $tile_size"
+echo "min_density: $min_density"
+
 
 # Do the environment setup
 # check if PYTHONPATH is set to the current directory
diff --git a/sean_sem_seg/inference.py b/sean_sem_seg/inference.py
index 6c24502..0950023 100644
--- a/sean_sem_seg/inference.py
+++ b/sean_sem_seg/inference.py
@@ -19,6 +19,10 @@ from tools import load_file, save_file
 import shutil
 import sys
 
+from other_parameters import other_parameters
+
+NUM_CLASSES = other_parameters['num_classes']
+
 sys.setrecursionlimit(10**8)  # Can be necessary for dealing with large point clouds.
 
 
@@ -60,11 +64,13 @@ def choose_most_confident_label(point_cloud, original_point_cloud):
     )
     _, indices = neighbours.kneighbors(original_point_cloud[:, :3])
 
-    labels = np.zeros((original_point_cloud.shape[0], 5))
-    labels[:, :4] = np.median(point_cloud[indices][:, :, -4:], axis=1)
-    labels[:, 4] = np.argmax(labels[:, :4], axis=1)
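+    # NUM_CLASSES columns of median neighbour confidences, plus a final argmax label column.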
+    labels = np.zeros((original_point_cloud.shape[0], NUM_CLASSES + 1))
+    labels[:, :NUM_CLASSES] = np.median(point_cloud[indices][:, :, -NUM_CLASSES:], axis=1)
+    labels[:, NUM_CLASSES] = np.argmax(labels[:, :NUM_CLASSES], axis=1)
 
-    original_point_cloud = np.hstack((original_point_cloud, labels[:, 4:]))
+    original_point_cloud = np.hstack((original_point_cloud, labels[:, NUM_CLASSES:]))
+
     return original_point_cloud
 
 class SemanticSegmentation:
@@ -99,7 +105,8 @@
 
         test_loader = DataLoader(test_dataset, batch_size=self.parameters["batch_size"], shuffle=False, num_workers=0)
 
-        model = Net(num_classes=4).to(self.device)
+        model = Net(num_classes=NUM_CLASSES).to(self.device)
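+        # The checkpoint must have been trained with NUM_CLASSES classes, or load_state_dict() will raise a size-mismatch error.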
         if self.parameters["use_CPU_only"]:
             model.load_state_dict(
                 torch.load(
diff --git a/sean_sem_seg/other_parameters.py b/sean_sem_seg/other_parameters.py
index 7c99795..6b77d62 100644
--- a/sean_sem_seg/other_parameters.py
+++ b/sean_sem_seg/other_parameters.py
@@ -9,6 +9,7 @@ other_parameters = dict(
     box_overlap=[0.5, 0.5, 0.5],  # Overlap of the sliding box used for semantic segmentation.
     min_points_per_box=1000,  # Minimum number of points for input to the model. Too few points and it becomes near impossible to accurately label them (though assuming vegetation class is the safest bet here).
     max_points_per_box=20000,  # Maximum number of points for input to the model. The model may tolerate higher numbers if you decrease the batch size accordingly (to fit on the GPU), but this is not tested.
+    num_classes=4,  # Number of classes the model predicts (4 or 6). Don't change this unless you are switching to a matching checkpoint.
     noise_class=0,  # Don't change
     terrain_class=1,  # Don't change
     vegetation_class=2,  # Don't change
diff --git a/sean_sem_seg/post_segmentation_script.py b/sean_sem_seg/post_segmentation_script.py
index 380aa83..cd80b49 100644
--- a/sean_sem_seg/post_segmentation_script.py
+++ b/sean_sem_seg/post_segmentation_script.py
@@ -106,6 +106,8 @@ class PostProcessing:
         return grid_points
 
     def process_point_cloud(self):
+        print("Processing point cloud...")
+
         self.terrain_points = self.point_cloud[
             self.point_cloud[:, self.label_index] == self.terrain_class_label
         ]  # -2 is now the class label as we added the height above DTM column.
-- 
GitLab