Skip to content
Snippets Groups Projects
Commit 04bde564 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

parallel implementation of dgcnn in pl

parent 36f46314
Branches
No related tags found
No related merge requests found
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from shapenet_data_dgcnn import ShapenetDataDgcnn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from model import DGCNN
class DGCNNLightning(pl.LightningModule):
    """Lightning wrapper around the DGCNN classifier.

    Delegates the forward pass to an internal ``DGCNN`` instance and
    implements the minimal Lightning hooks needed for supervised
    classification training.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Underlying DGCNN backbone producing per-class logits.
        self.dgcnn = DGCNN(num_classes)

    def forward(self, x):
        # Forward straight through the wrapped backbone.
        return self.dgcnn(x)

    def training_step(self, batch, batch_idx):
        # Batch layout: (points, <unused>, class labels).
        points, _, labels = batch
        logits = self(points)
        # NOTE(review): ignore_index=255 suggests 255 marks "no label";
        # confirm the dataset actually emits that sentinel.
        loss = F.cross_entropy(logits, labels, reduction='mean', ignore_index=255)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Plain Adam over all trainable parameters.
        return torch.optim.Adam(self.parameters(), lr=0.0001)
# get data
shapenet_data = ShapenetDataDgcnn(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=256,
    return_cls_label=True,
    small_data=False,
    small_data_size=1000,
    just_one_class=False,
    split='train',
    norm=True
)

# create a dataloader
dataloader = torch.utils.data.DataLoader(
    shapenet_data,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    drop_last=True
)

# Initialize a trainer.
wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio")
# Fix: the original passed both `devices=[0]` and the deprecated `gpus=1`;
# Lightning rejects setting both, and `gpus` was removed in PL 2.0.
# `devices=[0]` alone already pins training to GPU 0.
trainer = pl.Trainer(accelerator="auto", devices=[0], max_epochs=3, logger=wandb_logger)

# Initialize a model (ShapeNet has 16 shape categories).
model = DGCNNLightning(num_classes=16)
wandb_logger.watch(model)

# Train the model on gpu
trainer.fit(model, dataloader)
......@@ -165,8 +165,6 @@ class ShapenetDataDgcnn(object):
point_set = self.normalize(point_set)
choice = np.random.choice(len(point_set), self.npoints, replace=True)
# chose the first npoints
choice = np.arange(self.npoints)
point_set = point_set[choice, :]
point_set = point_set.astype(np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment