From 0441820e57da3a92f7ad05f75c523c7770c9ed9f Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Thu, 16 Mar 2023 11:12:29 +0100 Subject: [PATCH] update of pl forest transformer --- forest_sem_seg_transformer_pl.py | 155 ++++++++++++++++++++++++++++--- 1 file changed, 143 insertions(+), 12 deletions(-) diff --git a/forest_sem_seg_transformer_pl.py b/forest_sem_seg_transformer_pl.py index 6d04356..35a4593 100644 --- a/forest_sem_seg_transformer_pl.py +++ b/forest_sem_seg_transformer_pl.py @@ -1,5 +1,7 @@ import argparse import importlib +import os +import shutil import hydra import numpy as np import omegaconf @@ -45,6 +47,18 @@ class ForestSemSegTransformer(pl.LightningModule): for label in self.seg_classes[cat]: self.seg_label_to_cat[label] = cat + self.results_dir = hydra.utils.to_absolute_path('results') + # create folder to save the results las files + if not os.path.exists(self.results_dir): + os.mkdir(self.results_dir) + + self.test_dataset = Dataset( + root=hydra.utils.to_absolute_path('data/forest_txt/validation_txt/'), + npoints=self.conf.num_point, + normal_channel=self.conf.normal, + normalize_point_cloud=False) + + def forward(self, data): return self.model(data) @@ -72,12 +86,79 @@ class ForestSemSegTransformer(pl.LightningModule): else: optimizer = torch.optim.SGD(self.model.parameters(), lr=self.conf.learning_rate, momentum=0.9) - # update learning rate - lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip) - for param_group in optimizer.param_groups: - param_group['lr'] = lr + return optimizer + + def random_scale_point_cloud_pth(self, batch_data, scale_low=0.8, scale_high=1.25): + """Randomly scale the point cloud. Scale is per point cloud. + Input: + batch_data: (B, N, 3) tensor, original batch of point clouds + scale_low: float, lower bound of the random scale factor + scale_high: float, upper bound of the random scale factor + Return: + (B, N, 3) tensor, scaled batch of point clouds + """ + B, N, C = batch_data.size() + scales = torch.FloatTensor(B).uniform_(scale_low, scale_high) + scales = scales.view(B, 1, 1) # Reshape for broadcasting + batch_data *= scales.to(batch_data.device) + return batch_data + + def shift_point_cloud_pth(self, batch_data, shift_range=0.1): + """Randomly shift point cloud. Shift is per point cloud. + Input: + batch_data: (B, N, 3) tensor, original batch of point clouds + shift_range: float, maximum distance to shift each point + Return: + (B, N, 3) tensor, shifted batch of point clouds + """ + B, N, C = batch_data.size() + shifts = torch.FloatTensor(B, 1, 3).uniform_(-shift_range, shift_range) + shifts = shifts.to(batch_data.device) + batch_data += shifts + return batch_data + + def shift_point_cloud_np(self, batch_data, shift_range=0.1): + """ Randomly shift point cloud. Shift is per point cloud. + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, shifted batch of point clouds + """ + # convert to numpy array + batch_data = batch_data.cpu().numpy() + + + B, N, C = batch_data.shape + shifts = np.random.uniform(-shift_range, shift_range, (B,3)) + for batch_index in range(B): + batch_data[batch_index,:,:] += shifts[batch_index,:] + + # convert back to torch tensor + batch_data = torch.from_numpy(batch_data).cuda() + + return batch_data + + def random_scale_point_cloud_np(self, batch_data, scale_low=0.8, scale_high=1.25): + """ Randomly scale the point cloud. Scale is per point cloud. + Input: + BxNx3 array, original batch of point clouds + Return: + BxNx3 array, scaled batch of point clouds + """ + # convert to numpy array + batch_data = batch_data.cpu().numpy() + + + B, N, C = batch_data.shape + scales = np.random.uniform(scale_low, scale_high, B) + for batch_index in range(B): + batch_data[batch_index,:,:] *= scales[batch_index] + + # convert back to torch tensor + batch_data = torch.from_numpy(batch_data).cuda() + return batch_data # define epoch se def on_train_epoch_start(self) -> None: @@ -89,17 +170,25 @@ class ForestSemSegTransformer(pl.LightningModule): self.model = self.model.train() self.acc_train_mean_correct = [] + + # update learning rate + lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip) + for param_group in self.optimizers().param_groups: + param_group['lr'] = lr def on_training_epoch_end(self): self.log("train_acc", np.mean(self.acc_train_mean_correct), on_step=True, on_epoch=True, prog_bar=True, logger=True) self.log("epoch", self.current_epoch, on_step=True, on_epoch=True, prog_bar=True, logger=True) + + def training_step(self, batch, batch_idx): - points, label = batch - points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3]) - points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3]) + # points[:, :, 0:3] = self.random_scale_point_cloud_np(points[:, :, 0:3]) + points[:, :, 0:3] = self.random_scale_point_cloud_pth(points[:, :, 0:3]) + # points[:, :, 0:3] = self.shift_point_cloud_np(points[:, :, 0:3]) + points[:, :, 0:3] = self.shift_point_cloud_pth(points[:, :, 0:3]) points = torch.Tensor(points) points, label = points.float().cuda(), label.long().cuda() @@ -107,6 +196,7 @@ class ForestSemSegTransformer(pl.LightningModule): [points, self.to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) + seg_pred = seg_pred.contiguous().view(-1, self.conf.num_part) target = label.view(-1, 1)[:, 0] pred_choice = seg_pred.data.max(1)[1] @@ -172,6 +262,31 @@ class ForestSemSegTransformer(pl.LightningModule): logits = cur_pred_val_logits[i, :, :] cur_pred_val[i, :] = np.argmax(logits[:, self.seg_classes[cat]], 1) + self.seg_classes[cat][0] + # get x,y,z coordinates of points + points = points.cpu().data.numpy() + points = points[:, :, 0:3] + points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2) + points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2) + # save points as text files in the results folder and preserve the same name as the original txt file + for i in range(cur_batch_size): + np.savetxt(os.path.join(self.results_dir, + self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_pred.txt') + ), + points_pd[i, :, :], fmt='%f %f %f %d') + + # copy original txt file to the results folder + shutil.copy(self.test_dataset.datapath[batch_idx * self.conf.batch_size + i], + os.path.join(self.results_dir, self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1] + )) + + # save ground truth labels as text files in the results folder and preserve the same name as the original txt file + np.savetxt(os.path.join( + self.results_dir, + self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_gt.txt') + ), + points_gt[i, :, :], fmt='%f %f %f %d') + + correct = np.sum(cur_pred_val == target) self.total_correct += correct self.total_seen += (cur_batch_size * NUM_POINT) @@ -196,8 +311,6 @@ class ForestSemSegTransformer(pl.LightningModule): return self.test_metrics - - class ForestDataset(pl.LightningDataModule): def __init__(self, conf): super().__init__() @@ -208,8 +321,18 @@ class ForestDataset(pl.LightningDataModule): self.train_dataset = None def setup(self, stage=None): - self.test_dataset = Dataset(root=self.test_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True) - self.train_dataset = Dataset(root=self.train_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True) + self.test_dataset = Dataset( + root=self.test_dataset_path, + npoints=self.conf.num_point, + normal_channel=self.conf.normal, + normalize_point_cloud=self.conf.normalize_point_cloud + ) + self.train_dataset = Dataset( + root=self.train_dataset_path, + npoints=self.conf.num_point, + normal_channel=self.conf.normal, + normalize_point_cloud=True + ) def train_dataloader(self): return torch.utils.data.DataLoader( @@ -232,10 +355,18 @@ class ForestDataset(pl.LightningDataModule): def main(args): omegaconf.OmegaConf.set_struct(args, False) + # add a parameter to args + trainer = pl.Trainer(gpus=1, max_epochs=args.epoch) + # args.normalize_point_cloud = True # trainer.fit(ForestSemSegTransformer(args), ForestDataset(args)) - trainer.validate(ForestSemSegTransformer(args), ForestDataset(args)) + args.normalize_point_cloud = True + trainer.validate( + ForestSemSegTransformer(args), + ForestDataset(args), + ckpt_path='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/log/partseg/Hengshuang/lightning_logs/version_71/checkpoints/epoch=9-step=30.ckpt' + ) if __name__ == '__main__': main() -- GitLab