From 04bde564410820b3bcabd84f80fd9afeb9a05656 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Fri, 24 Mar 2023 13:15:41 +0100
Subject: [PATCH] Parallel implementation of DGCNN in PyTorch Lightning

---
 dgcnn/dgcnn_train_pl.py      | 62 ++++++++++++++++++++++++++++++++++++
 dgcnn/shapenet_data_dgcnn.py |  2 --
 2 files changed, 62 insertions(+), 2 deletions(-)
 create mode 100644 dgcnn/dgcnn_train_pl.py

diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
new file mode 100644
index 0000000..e6ba0c2
--- /dev/null
+++ b/dgcnn/dgcnn_train_pl.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+from shapenet_data_dgcnn import ShapenetDataDgcnn
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from model import DGCNN
+
+
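+# Lightning wrapper around the DGCNN classifier: forward pass, training loss, and optimizer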
+class DGCNNLightning(pl.LightningModule):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.dgcnn = DGCNN(num_classes)
+
+    def forward(self, x):
+        return self.dgcnn(x)
+
+    def training_step(self, batch, batch_idx):
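+        # the dataset returns a 3-tuple; only the point cloud and the class label are used here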
+        points, _, class_name = batch
+        pred = self(points)
+        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        self.log('train_loss', loss)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
+        return optimizer
+
+# get data
+shapenet_data = ShapenetDataDgcnn(
+    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+    npoints=256,
+    return_cls_label=True,
+    small_data=False,
+    small_data_size=1000,
+    just_one_class=False,
+    split='train',
+    norm=True
+)
+
+# create a dataloader
+dataloader = torch.utils.data.DataLoader(
+    shapenet_data,
+    batch_size=4,
+    shuffle=True,
+    num_workers=8,
+    drop_last=True
+)
+
+
+# Initialize the wandb logger and the trainer
+
+wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio")
+
+trainer = pl.Trainer(accelerator="gpu", devices=[0], max_epochs=3, logger=wandb_logger)
+
+# Initialize a model
+model = DGCNNLightning(num_classes=16)
+wandb_logger.watch(model)
+# Train the model on the GPU
+trainer.fit(model, dataloader)
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index 275ca4e..501caff 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -165,8 +165,6 @@ class ShapenetDataDgcnn(object):
             point_set = self.normalize(point_set)
 
         choice = np.random.choice(len(point_set), self.npoints, replace=True)
-        # chose the first npoints
-        choice = np.arange(self.npoints)
 
         point_set = point_set[choice, :]
         point_set = point_set.astype(np.float32)
-- 
GitLab