Commit 0441820e authored by Maciej Wielgosz

update of pl forest transformer

parent 8338af5b
import argparse
import importlib
import os
import shutil

import hydra
import numpy as np
import omegaconf
@@ -45,6 +47,18 @@ class ForestSemSegTransformer(pl.LightningModule):
            for label in self.seg_classes[cat]:
                self.seg_label_to_cat[label] = cat
        self.results_dir = hydra.utils.to_absolute_path('results')

        # create a folder to save the result files
        if not os.path.exists(self.results_dir):
            os.mkdir(self.results_dir)

        self.test_dataset = Dataset(
            root=hydra.utils.to_absolute_path('data/forest_txt/validation_txt/'),
            npoints=self.conf.num_point,
            normal_channel=self.conf.normal,
            normalize_point_cloud=False)
    def forward(self, data):
        return self.model(data)
@@ -72,12 +86,79 @@ class ForestSemSegTransformer(pl.LightningModule):
        else:
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.conf.learning_rate, momentum=0.9)

        # update learning rate
        lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return optimizer
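    # NOTE (hypothetical alternative, not part of this commit): the clamped step
    # decay applied manually here and in on_train_epoch_start below could instead
    # be expressed as a LambdaLR schedule, which Lightning steps automatically:
    #
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.model.parameters(), lr=self.conf.learning_rate)  # or SGD, as selected above
    #     decay = lambda epoch: max(
    #         self.conf.lr_decay ** (epoch // self.conf.step_size),
    #         self.learning_rate_clip / self.conf.learning_rate)
    #     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay)
    #     return [optimizer], [scheduler]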
    def random_scale_point_cloud_pth(self, batch_data, scale_low=0.8, scale_high=1.25):
        """Randomly scale the point cloud. Scale is per point cloud.
        Input:
            batch_data: (B, N, 3) tensor, original batch of point clouds
            scale_low: float, lower bound of the random scale factor
            scale_high: float, upper bound of the random scale factor
        Return:
            (B, N, 3) tensor, scaled batch of point clouds
        """
        B, N, C = batch_data.size()
        scales = torch.FloatTensor(B).uniform_(scale_low, scale_high)
        scales = scales.view(B, 1, 1)  # reshape for broadcasting over points and channels
        batch_data *= scales.to(batch_data.device)
        return batch_data
    def shift_point_cloud_pth(self, batch_data, shift_range=0.1):
        """Randomly shift point cloud. Shift is per point cloud.
        Input:
            batch_data: (B, N, 3) tensor, original batch of point clouds
            shift_range: float, maximum distance to shift each point
        Return:
            (B, N, 3) tensor, shifted batch of point clouds
        """
        B, N, C = batch_data.size()
        shifts = torch.FloatTensor(B, 1, 3).uniform_(-shift_range, shift_range)
        shifts = shifts.to(batch_data.device)
        batch_data += shifts
        return batch_data
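    # Usage sketch (hypothetical, for illustration only): both torch-based
    # augmentations act per point cloud, stay on the input's device, and
    # preserve the (B, N, 3) shape, e.g.
    #
    # pts = torch.rand(8, 2048, 3)                  # dummy batch of 8 clouds
    # pts = self.random_scale_point_cloud_pth(pts)  # per-cloud scale in [0.8, 1.25]
    # pts = self.shift_point_cloud_pth(pts)         # per-cloud shift in [-0.1, 0.1]
    # assert pts.shape == (8, 2048, 3)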
    def shift_point_cloud_np(self, batch_data, shift_range=0.1):
        """Randomly shift point cloud. Shift is per point cloud.
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, shifted batch of point clouds
        """
        # convert to a numpy array (round-trips the data through host memory)
        batch_data = batch_data.cpu().numpy()
        B, N, C = batch_data.shape
        shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
        for batch_index in range(B):
            batch_data[batch_index, :, :] += shifts[batch_index, :]
        # convert back to a torch tensor on the default CUDA device
        batch_data = torch.from_numpy(batch_data).cuda()
        return batch_data
    def random_scale_point_cloud_np(self, batch_data, scale_low=0.8, scale_high=1.25):
        """Randomly scale the point cloud. Scale is per point cloud.
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, scaled batch of point clouds
        """
        # convert to a numpy array (round-trips the data through host memory)
        batch_data = batch_data.cpu().numpy()
        B, N, C = batch_data.shape
        scales = np.random.uniform(scale_low, scale_high, B)
        for batch_index in range(B):
            batch_data[batch_index, :, :] *= scales[batch_index]
        # convert back to a torch tensor on the default CUDA device
        batch_data = torch.from_numpy(batch_data).cuda()
        return batch_data
    # define epoch setup
    def on_train_epoch_start(self) -> None:
@@ -89,17 +170,25 @@ class ForestSemSegTransformer(pl.LightningModule):
        self.model = self.model.train()
        self.acc_train_mean_correct = []

        # update learning rate
        lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
        for param_group in self.optimizers().param_groups:
            param_group['lr'] = lr
    def on_train_epoch_end(self):
        self.log("train_acc", np.mean(self.acc_train_mean_correct), on_epoch=True, prog_bar=True, logger=True)
        self.log("epoch", self.current_epoch, on_epoch=True, prog_bar=True, logger=True)
    def training_step(self, batch, batch_idx):
        points, label = batch

        # points[:, :, 0:3] = self.random_scale_point_cloud_np(points[:, :, 0:3])
        points[:, :, 0:3] = self.random_scale_point_cloud_pth(points[:, :, 0:3])
        # points[:, :, 0:3] = self.shift_point_cloud_np(points[:, :, 0:3])
        points[:, :, 0:3] = self.shift_point_cloud_pth(points[:, :, 0:3])

        points = torch.Tensor(points)
        points, label = points.float().cuda(), label.long().cuda()
@@ -107,6 +196,7 @@ class ForestSemSegTransformer(pl.LightningModule):
            [points, self.to_categorical(torch.ones((points.shape[0], 1),
                dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
        seg_pred = seg_pred.contiguous().view(-1, self.conf.num_part)
        target = label.view(-1, 1)[:, 0]
        pred_choice = seg_pred.data.max(1)[1]
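        # The to_categorical helper used above is defined outside this hunk; a
        # minimal sketch of such a one-hot encoder (hypothetical reconstruction,
        # not necessarily this file's actual definition):
        #
        # def to_categorical(self, y, num_classes):
        #     """One-hot encode integer-valued class indices y into num_classes channels."""
        #     return torch.eye(num_classes, device=y.device)[y.long()]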
@@ -172,6 +262,31 @@ class ForestSemSegTransformer(pl.LightningModule):
                logits = cur_pred_val_logits[i, :, :]
                cur_pred_val[i, :] = np.argmax(logits[:, self.seg_classes[cat]], 1) + self.seg_classes[cat][0]
            # get the x, y, z coordinates of the points
            points = points.cpu().data.numpy()
            points = points[:, :, 0:3]
            points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2)
            points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2)

            # save predictions as text files in the results folder, preserving the name of the original txt file
            for i in range(cur_batch_size):
                np.savetxt(os.path.join(self.results_dir,
                                        self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_pred.txt')),
                           points_pd[i, :, :], fmt='%f %f %f %d')

                # copy the original txt file to the results folder
                shutil.copy(self.test_dataset.datapath[batch_idx * self.conf.batch_size + i],
                            os.path.join(self.results_dir,
                                         self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1]))

                # save ground-truth labels as text files in the results folder, preserving the name of the original txt file
                np.savetxt(os.path.join(self.results_dir,
                                        self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_gt.txt')),
                           points_gt[i, :, :], fmt='%f %f %f %d')
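            # Quick sanity check for the files written above (hypothetical snippet;
            # the '%f %f %f %d' format means three float coordinates plus one
            # integer label per row; the file name is illustrative):
            #
            # pred = np.loadtxt(os.path.join(self.results_dir, 'some_plot_pred.txt'))
            # xyz, labels = pred[:, 0:3], pred[:, 3].astype(int)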
            correct = np.sum(cur_pred_val == target)
            self.total_correct += correct
            self.total_seen += (cur_batch_size * NUM_POINT)
@@ -196,8 +311,6 @@ class ForestSemSegTransformer(pl.LightningModule):
        return self.test_metrics
class ForestDataset(pl.LightningDataModule):
    def __init__(self, conf):
        super().__init__()
@@ -208,8 +321,18 @@ class ForestDataset(pl.LightningDataModule):
        self.train_dataset = None
    def setup(self, stage=None):
        self.test_dataset = Dataset(
            root=self.test_dataset_path,
            npoints=self.conf.num_point,
            normal_channel=self.conf.normal,
            normalize_point_cloud=self.conf.normalize_point_cloud
        )
        self.train_dataset = Dataset(
            root=self.train_dataset_path,
            npoints=self.conf.num_point,
            normal_channel=self.conf.normal,
            normalize_point_cloud=True
        )
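    # Standalone usage sketch (hypothetical): the DataModule can be exercised
    # outside the Trainer, assuming a conf object with the fields referenced above:
    #
    # dm = ForestDataset(conf)
    # dm.setup()
    # points, label = next(iter(dm.train_dataloader()))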
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
@@ -232,10 +355,18 @@ class ForestDataset(pl.LightningDataModule):
def main(args):
    omegaconf.OmegaConf.set_struct(args, False)

    trainer = pl.Trainer(gpus=1, max_epochs=args.epoch)

    # add a parameter to args
    args.normalize_point_cloud = True

    # trainer.fit(ForestSemSegTransformer(args), ForestDataset(args))
    trainer.validate(
        ForestSemSegTransformer(args),
        ForestDataset(args),
        ckpt_path='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/log/partseg/Hengshuang/lightning_logs/version_71/checkpoints/epoch=9-step=30.ckpt'
    )
if __name__ == '__main__':
    main()