Skip to content
Snippets Groups Projects
Commit 0441820e authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update of pl forest transformer

parent 8338af5b
Branches
No related tags found
No related merge requests found
import argparse
import importlib
import os
import shutil
import hydra
import numpy as np
import omegaconf
......@@ -45,6 +47,18 @@ class ForestSemSegTransformer(pl.LightningModule):
for label in self.seg_classes[cat]:
self.seg_label_to_cat[label] = cat
self.results_dir = hydra.utils.to_absolute_path('results')
# create folder to save the results las files
if not os.path.exists(self.results_dir):
os.mkdir(self.results_dir)
self.test_dataset = Dataset(
root=hydra.utils.to_absolute_path('data/forest_txt/validation_txt/'),
npoints=self.conf.num_point,
normal_channel=self.conf.normal,
normalize_point_cloud=False)
def forward(self, data):
return self.model(data)
......@@ -72,12 +86,79 @@ class ForestSemSegTransformer(pl.LightningModule):
else:
optimizer = torch.optim.SGD(self.model.parameters(), lr=self.conf.learning_rate, momentum=0.9)
# update learning rate
lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return optimizer
def random_scale_point_cloud_pth(self, batch_data, scale_low=0.8, scale_high=1.25):
"""Randomly scale the point cloud. Scale is per point cloud.
Input:
batch_data: (B, N, 3) tensor, original batch of point clouds
scale_low: float, lower bound of the random scale factor
scale_high: float, upper bound of the random scale factor
Return:
(B, N, 3) tensor, scaled batch of point clouds
"""
B, N, C = batch_data.size()
scales = torch.FloatTensor(B).uniform_(scale_low, scale_high)
scales = scales.view(B, 1, 1) # Reshape for broadcasting
batch_data *= scales.to(batch_data.device)
return batch_data
def shift_point_cloud_pth(self, batch_data, shift_range=0.1):
"""Randomly shift point cloud. Shift is per point cloud.
Input:
batch_data: (B, N, 3) tensor, original batch of point clouds
shift_range: float, maximum distance to shift each point
Return:
(B, N, 3) tensor, shifted batch of point clouds
"""
B, N, C = batch_data.size()
shifts = torch.FloatTensor(B, 1, 3).uniform_(-shift_range, shift_range)
shifts = shifts.to(batch_data.device)
batch_data += shifts
return batch_data
def shift_point_cloud_np(self, batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
# convert to numpy array
batch_data = batch_data.cpu().numpy()
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
# convert back to torch tensor
batch_data = torch.from_numpy(batch_data).cuda()
return batch_data
def random_scale_point_cloud_np(self, batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
# convert to numpy array
batch_data = batch_data.cpu().numpy()
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
# convert back to torch tensor
batch_data = torch.from_numpy(batch_data).cuda()
return batch_data
# define epoch se
def on_train_epoch_start(self) -> None:
......@@ -89,17 +170,25 @@ class ForestSemSegTransformer(pl.LightningModule):
self.model = self.model.train()
self.acc_train_mean_correct = []
# update learning rate
lr = max(self.conf.learning_rate * (self.conf.lr_decay ** (self.current_epoch // self.conf.step_size)), self.learning_rate_clip)
for param_group in self.optimizers().param_groups:
param_group['lr'] = lr
def on_training_epoch_end(self):
self.log("train_acc", np.mean(self.acc_train_mean_correct), on_step=True, on_epoch=True, prog_bar=True, logger=True)
self.log("epoch", self.current_epoch, on_step=True, on_epoch=True, prog_bar=True, logger=True)
def training_step(self, batch, batch_idx):
points, label = batch
points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
# points[:, :, 0:3] = self.random_scale_point_cloud_np(points[:, :, 0:3])
points[:, :, 0:3] = self.random_scale_point_cloud_pth(points[:, :, 0:3])
# points[:, :, 0:3] = self.shift_point_cloud_np(points[:, :, 0:3])
points[:, :, 0:3] = self.shift_point_cloud_pth(points[:, :, 0:3])
points = torch.Tensor(points)
points, label = points.float().cuda(), label.long().cuda()
......@@ -107,6 +196,7 @@ class ForestSemSegTransformer(pl.LightningModule):
[points, self.to_categorical(torch.ones((points.shape[0], 1),
dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
seg_pred = seg_pred.contiguous().view(-1, self.conf.num_part)
target = label.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1]
......@@ -172,6 +262,31 @@ class ForestSemSegTransformer(pl.LightningModule):
logits = cur_pred_val_logits[i, :, :]
cur_pred_val[i, :] = np.argmax(logits[:, self.seg_classes[cat]], 1) + self.seg_classes[cat][0]
# get x,y,z coordinates of points
points = points.cpu().data.numpy()
points = points[:, :, 0:3]
points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2)
points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2)
# save points as text files in the results folder and preserve the same name as the original txt file
for i in range(cur_batch_size):
np.savetxt(os.path.join(self.results_dir,
self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_pred.txt')
),
points_pd[i, :, :], fmt='%f %f %f %d')
# copy original txt file to the results folder
shutil.copy(self.test_dataset.datapath[batch_idx * self.conf.batch_size + i],
os.path.join(self.results_dir, self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1]
))
# save ground truth labels as text files in the results folder and preserve the same name as the original txt file
np.savetxt(os.path.join(
self.results_dir,
self.test_dataset.datapath[batch_idx * self.conf.batch_size + i].split('/')[-1].replace('.txt', '_gt.txt')
),
points_gt[i, :, :], fmt='%f %f %f %d')
correct = np.sum(cur_pred_val == target)
self.total_correct += correct
self.total_seen += (cur_batch_size * NUM_POINT)
......@@ -196,8 +311,6 @@ class ForestSemSegTransformer(pl.LightningModule):
return self.test_metrics
class ForestDataset(pl.LightningDataModule):
def __init__(self, conf):
super().__init__()
......@@ -208,8 +321,18 @@ class ForestDataset(pl.LightningDataModule):
self.train_dataset = None
def setup(self, stage=None):
self.test_dataset = Dataset(root=self.test_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True)
self.train_dataset = Dataset(root=self.train_dataset_path, npoints=self.conf.num_point, normal_channel=self.conf.normal, normalize_point_cloud=True)
self.test_dataset = Dataset(
root=self.test_dataset_path,
npoints=self.conf.num_point,
normal_channel=self.conf.normal,
normalize_point_cloud=self.conf.normalize_point_cloud
)
self.train_dataset = Dataset(
root=self.train_dataset_path,
npoints=self.conf.num_point,
normal_channel=self.conf.normal,
normalize_point_cloud=True
)
def train_dataloader(self):
return torch.utils.data.DataLoader(
......@@ -232,10 +355,18 @@ class ForestDataset(pl.LightningDataModule):
def main(args):
omegaconf.OmegaConf.set_struct(args, False)
# add a parameter to args
trainer = pl.Trainer(gpus=1, max_epochs=args.epoch)
# args.normalize_point_cloud = True
# trainer.fit(ForestSemSegTransformer(args), ForestDataset(args))
trainer.validate(ForestSemSegTransformer(args), ForestDataset(args))
args.normalize_point_cloud = True
trainer.validate(
ForestSemSegTransformer(args),
ForestDataset(args),
ckpt_path='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/log/partseg/Hengshuang/lightning_logs/version_71/checkpoints/epoch=9-step=30.ckpt'
)
if __name__ == '__main__':
main()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment