From 04bde564410820b3bcabd84f80fd9afeb9a05656 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Fri, 24 Mar 2023 13:15:41 +0100 Subject: [PATCH] parallel implementation of dgcnn in pl --- dgcnn/dgcnn_train_pl.py | 63 ++++++++++++++++++++++++++++++++++++ dgcnn/shapenet_data_dgcnn.py | 2 -- 2 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 dgcnn/dgcnn_train_pl.py diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py new file mode 100644 index 0000000..e6ba0c2 --- /dev/null +++ b/dgcnn/dgcnn_train_pl.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from shapenet_data_dgcnn import ShapenetDataDgcnn +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from model import DGCNN + + + +class DGCNNLightning(pl.LightningModule): + def __init__(self, num_classes): + super().__init__() + self.dgcnn = DGCNN(num_classes) + + def forward(self, x): + return self.dgcnn(x) + + def training_step(self, batch, batch_idx): + points, _, class_name = batch + pred = self(points) + loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=0.0001) + return optimizer + + # get data +shapenet_data = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=256, + return_cls_label=True, + small_data=False, + small_data_size=1000, + just_one_class=False, + split='train', + norm=True + ) + + # create a dataloader +dataloader = torch.utils.data.DataLoader( + shapenet_data, + batch_size=4, + shuffle=True, + num_workers=8, + drop_last=True + ) + + +# Initialize a trainer + +wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio") + +trainer = pl.Trainer(accelerator="auto", devices=[0], max_epochs=3, logger=wandb_logger, gpus=1) + +# Initialize a model +model = DGCNNLightning(num_classes=16) +wandb_logger.watch(model) +# Train the model on gpu +trainer.fit(model, dataloader) diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 275ca4e..501caff 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -165,8 +165,6 @@ class ShapenetDataDgcnn(object): point_set = self.normalize(point_set) choice = np.random.choice(len(point_set), self.npoints, replace=True) - # chose the first npoints - choice = np.arange(self.npoints) point_set = point_set[choice, :] point_set = point_set.astype(np.float32) -- GitLab