From 0441820e57da3a92f7ad05f75c523c7770c9ed9f Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 16 Mar 2023 11:12:29 +0100
Subject: [PATCH] Update the PL forest transformer: torch augmentations, per-epoch LR decay, and result export

---
 forest_sem_seg_transformer_pl.py | 155 ++++++++++++++++++++++++++++---
 1 file changed, 143 insertions(+), 12 deletions(-)

diff --git a/forest_sem_seg_transformer_pl.py b/forest_sem_seg_transformer_pl.py
index 6d04356..35a4593 100644
--- a/forest_sem_seg_transformer_pl.py
+++ b/forest_sem_seg_transformer_pl.py
@@ -1,5 +1,7 @@
 import argparse
 import importlib
+import os
+import shutil
 import hydra
 import numpy as np
 import omegaconf
@@ -45,6 +47,18 @@ class ForestSemSegTransformer(pl.LightningModule):
             for label in self.seg_classes[cat]:
                 self.seg_label_to_cat[label] = cat
 
+        self.results_dir = hydra.utils.to_absolute_path('results')
+        # create the folder where the result txt files are saved
+        os.makedirs(self.results_dir, exist_ok=True)
+
+        # validation split, used when exporting per-cloud prediction files
+        self.test_dataset = Dataset(
+            root=hydra.utils.to_absolute_path('data/forest_txt/validation_txt/'),
+            npoints=self.conf.num_point,
+            normal_channel=self.conf.normal,
+            normalize_point_cloud=False)
+
 
     def forward(self, data):
         return self.model(data)
@@ -72,12 +86,79 @@ class ForestSemSegTransformer(pl.LightningModule):
         else:
             optimizer = torch.optim.SGD(self.model.parameters(), lr=self.conf.learning_rate, momentum=0.9)
 
-        # update learning rate
-        lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
         
         return optimizer
+
+    def random_scale_point_cloud_pth(self, batch_data, scale_low=0.8, scale_high=1.25):
+        """Randomly scale the point cloud. Scale is per point cloud.
+        Input:
+            batch_data: (B, N, 3) tensor, original batch of point clouds
+            scale_low: float, lower bound of the random scale factor
+            scale_high: float, upper bound of the random scale factor
+        Return:
+            (B, N, 3) tensor, scaled batch of point clouds
+        """
+        B, N, C = batch_data.size()
+        scales = torch.FloatTensor(B).uniform_(scale_low, scale_high)
+        scales = scales.view(B, 1, 1)  # Reshape for broadcasting
+        batch_data *= scales.to(batch_data.device)
+        return batch_data
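+    # a minimal usage sketch (shapes chosen for illustration only):
+    #   pts = torch.rand(4, 1024, 3)                  # B=4 clouds, N=1024 points
+    #   pts = self.random_scale_point_cloud_pth(pts)  # one factor in [0.8, 1.25) per cloud
+    # note that the scaling is applied in place, so the caller's tensor is mutated too.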
+    
+    def shift_point_cloud_pth(self, batch_data, shift_range=0.1):
+        """Randomly shift point cloud. Shift is per point cloud.
+        Input:
+            batch_data: (B, N, 3) tensor, original batch of point clouds
+            shift_range: float, maximum magnitude of the per-axis shift applied to each cloud
+        Return:
+            (B, N, 3) tensor, shifted batch of point clouds
+        """
+        B, N, C = batch_data.size()
+        shifts = torch.FloatTensor(B, 1, 3).uniform_(-shift_range, shift_range)
+        shifts = shifts.to(batch_data.device)
+        batch_data += shifts
+        return batch_data
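+    # e.g. with shift_range=0.1, every point in a given cloud is moved by the
+    # same random (dx, dy, dz) offset, each component drawn from [-0.1, 0.1);
+    # the (B, 1, 3) shift tensor broadcasts over the N points. As with the
+    # scaling above, the shift is applied in place.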
+    
+    def shift_point_cloud_np(self, batch_data, shift_range=0.1):
+        """Randomly shift point cloud. Shift is per point cloud.
+        Input:
+            batch_data: (B, N, 3) tensor, original batch of point clouds
+            shift_range: float, maximum magnitude of the per-axis shift applied to each cloud
+        Return:
+            (B, N, 3) tensor, shifted batch of point clouds
+        """
+        # round-trip through numpy, remembering the original device
+        device = batch_data.device
+        batch_data = batch_data.cpu().numpy()
+
+        B, N, C = batch_data.shape
+        shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
+        for batch_index in range(B):
+            batch_data[batch_index, :, :] += shifts[batch_index, :]
+
+        # convert back to a torch tensor on the original device
+        batch_data = torch.from_numpy(batch_data).to(device)
+        return batch_data
+
+    def random_scale_point_cloud_np(self, batch_data, scale_low=0.8, scale_high=1.25):
+        """Randomly scale the point cloud. Scale is per point cloud.
+        Input:
+            batch_data: (B, N, 3) tensor, original batch of point clouds
+            scale_low: float, lower bound of the random scale factor
+            scale_high: float, upper bound of the random scale factor
+        Return:
+            (B, N, 3) tensor, scaled batch of point clouds
+        """
+        # round-trip through numpy, remembering the original device
+        device = batch_data.device
+        batch_data = batch_data.cpu().numpy()
+
+        B, N, C = batch_data.shape
+        scales = np.random.uniform(scale_low, scale_high, B)
+        for batch_index in range(B):
+            batch_data[batch_index, :, :] *= scales[batch_index]
+
+        # convert back to a torch tensor on the original device
+        batch_data = torch.from_numpy(batch_data).to(device)
+        return batch_data
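+    # the *_np variants mirror the torch versions above via a numpy round-trip;
+    # they are slower but kept for reference (see the commented-out calls in
+    # training_step below).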
     
     # define epoch setup
     def on_train_epoch_start(self) -> None:
@@ -89,17 +170,25 @@ class ForestSemSegTransformer(pl.LightningModule):
         self.model = self.model.train()
 
         self.acc_train_mean_correct = []
+
+        # update learning rate
+        lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
+        for param_group in self.optimizers().param_groups:
+            param_group['lr'] = lr
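+        # e.g. with learning_rate=0.001, lr_decay=0.5 and step_size=20
+        # (illustrative values, not necessarily the project defaults),
+        # epochs 0-19 run at 1e-3, epochs 20-39 at 5e-4, and so on,
+        # clipped from below by self.learning_rate_clip.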
        
     def on_train_epoch_end(self):
         self.log("train_acc", np.mean(self.acc_train_mean_correct), on_epoch=True, prog_bar=True, logger=True)
         self.log("epoch", self.current_epoch, on_epoch=True, prog_bar=True, logger=True)
 
+
     def training_step(self, batch, batch_idx):
- 
 
         points, label = batch
-        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
-        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
+        # torch augmentations; the numpy variants are kept for reference
+        points[:, :, 0:3] = self.random_scale_point_cloud_pth(points[:, :, 0:3])
+        points[:, :, 0:3] = self.shift_point_cloud_pth(points[:, :, 0:3])
+        # points[:, :, 0:3] = self.random_scale_point_cloud_np(points[:, :, 0:3])
+        # points[:, :, 0:3] = self.shift_point_cloud_np(points[:, :, 0:3])
         points = torch.Tensor(points)
 
         points, label = points.float().cuda(), label.long().cuda()
@@ -107,6 +196,7 @@ class ForestSemSegTransformer(pl.LightningModule):
             [points, self.to_categorical(torch.ones((points.shape[0], 1), 
                                                     dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
         
+        # flatten predictions to (B*N, num_part) so each row lines up with
+        # the flattened per-point targets below
+        seg_pred = seg_pred.contiguous().view(-1, self.conf.num_part)
         target = label.view(-1, 1)[:, 0]
         pred_choice = seg_pred.data.max(1)[1]
 
@@ -172,6 +262,31 @@ class ForestSemSegTransformer(pl.LightningModule):
                 logits = cur_pred_val_logits[i, :, :]
                 cur_pred_val[i, :] = np.argmax(logits[:, self.seg_classes[cat]], 1) + self.seg_classes[cat][0]
 
+            # get x, y, z coordinates of the points
+            points = points.cpu().data.numpy()
+            points = points[:, :, 0:3]
+            points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2)
+            points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2)
+            # save predictions and ground-truth labels as txt files in the
+            # results folder, preserving the name of the original txt file
+            for i in range(cur_batch_size):
+                src_path = self.test_dataset.datapath[batch_idx * self.conf.batch_size + i]
+                base_name = os.path.basename(src_path)
+
+                # predicted labels, one "x y z label" row per point
+                np.savetxt(os.path.join(self.results_dir, base_name.replace('.txt', '_pred.txt')),
+                           points_pd[i, :, :], fmt='%f %f %f %d')
+
+                # copy the original txt file to the results folder
+                shutil.copy(src_path, os.path.join(self.results_dir, base_name))
+
+                # ground-truth labels in the same "x y z label" format
+                np.savetxt(os.path.join(self.results_dir, base_name.replace('.txt', '_gt.txt')),
+                           points_gt[i, :, :], fmt='%f %f %f %d')
+
             correct = np.sum(cur_pred_val == target)
             self.total_correct += correct
             self.total_seen += (cur_batch_size * NUM_POINT)
@@ -196,8 +311,6 @@ class ForestSemSegTransformer(pl.LightningModule):
 
         return self.test_metrics
 
-
-
 class ForestDataset(pl.LightningDataModule):
     def __init__(self, conf):
         super().__init__()
@@ -208,8 +321,18 @@ class ForestDataset(pl.LightningDataModule):
         self.train_dataset = None
 
     def setup(self, stage=None):
-        self.test_dataset = Dataset(root=self.test_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True)
-        self.train_dataset = Dataset(root=self.train_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True)
+        self.test_dataset = Dataset(
+            root=self.test_dataset_path, 
+            npoints=self.conf.num_point, 
+            normal_channel=self.conf.normal, 
+            normalize_point_cloud=self.conf.normalize_point_cloud
+            )
+        self.train_dataset = Dataset(
+            root=self.train_dataset_path, 
+            npoints=self.conf.num_point, 
+            normal_channel=self.conf.normal, 
+            normalize_point_cloud=True
+            )
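+        # note: the training split is always normalized, while the test split
+        # reads the flag from the config so raw coordinates can be preserved
+        # when exporting prediction files.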
 
     def train_dataloader(self):
         return torch.utils.data.DataLoader(
@@ -232,10 +355,18 @@ class ForestDataset(pl.LightningDataModule):
 def main(args):
 
     omegaconf.OmegaConf.set_struct(args, False)
 
     trainer = pl.Trainer(gpus=1, max_epochs=args.epoch)
     # trainer.fit(ForestSemSegTransformer(args), ForestDataset(args))
-    trainer.validate(ForestSemSegTransformer(args), ForestDataset(args))
+    # struct mode was disabled above, so this extra flag can be injected into the config
+    args.normalize_point_cloud = True
+    trainer.validate(
+        ForestSemSegTransformer(args), 
+        ForestDataset(args), 
+        ckpt_path='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/log/partseg/Hengshuang/lightning_logs/version_71/checkpoints/epoch=9-step=30.ckpt'
+        )
 
 if __name__ == '__main__':
     main()
-- 
GitLab