From 41e3da5f2af27600bc0f566c58403a6a5fe8384d Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 1 Jun 2023 11:45:46 +0200
Subject: [PATCH] parallel implementation with PyTorch Lightning

---
 PyG_implementation/dgcnn_parallel.py | 188 +++++++++++++++++++++++++++
 1 file changed, 188 insertions(+)
 create mode 100644 PyG_implementation/dgcnn_parallel.py

diff --git a/PyG_implementation/dgcnn_parallel.py b/PyG_implementation/dgcnn_parallel.py
new file mode 100644
index 0000000..3c7ca3b
--- /dev/null
+++ b/PyG_implementation/dgcnn_parallel.py
@@ -0,0 +1,188 @@
+import os
+import wandb
+import random
+import numpy as np
+from tqdm.auto import tqdm
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import WandbLogger
+from torch.optim.lr_scheduler import StepLR
+import torch
+import torch.nn.functional as F
+from torch_scatter import scatter
+from torchmetrics.functional import jaccard_index
+
+import torch_geometric.transforms as T
+from torch_geometric.datasets import ShapeNet
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import MLP, DynamicEdgeConv
+
+from my_data_loader import MyData
+
+
+wandb_project = "pyg-point-cloud" #@param {"type": "string"} , maciej-wielgosz-nibio
+wandb_run_name = "train-dgcnn" #@param {"type": "string"}
+
+wandb.init(
+    entity="maciej-wielgosz-nibio",
+    project=wandb_project,
+    name=wandb_run_name,
+    job_type="train"
+)
+
+config = wandb.config
+
+config.seed = 42
+config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+random.seed(config.seed)
+torch.manual_seed(config.seed)
+device = torch.device(config.device)
+
+config.category = 'Car' #@param ["Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"}
+config.random_jitter_translation = 1e-2
+config.random_rotation_interval_x = 15
+config.random_rotation_interval_y = 15
+config.random_rotation_interval_z = 15
+config.validation_split = 0.2
+config.batch_size = 4
+config.num_workers = 6
+
+config.num_nearest_neighbours = 30
+config.aggregation_operator = "max"
+config.dropout = 0.5
+config.initial_lr = 1e-3
+config.lr_scheduler_step_size = 5
+config.gamma = 0.8
+
+config.epochs = 1
+
+
+transform = T.Compose([
+    T.RandomJitter(config.random_jitter_translation),
+    T.RandomRotate(config.random_rotation_interval_x, axis=0),
+    T.RandomRotate(config.random_rotation_interval_y, axis=1),
+    T.RandomRotate(config.random_rotation_interval_z, axis=2)
+])
+pre_transform = T.NormalizeScale()
+
+
+# dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/maciek_data/plane_maciek"
+dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/nibio_data_no_commas"
+
+train_val_dataset = MyData(
+    dataset_path, config.category, split='trainval',
+    transform=transform, pre_transform=pre_transform
+)
+
+segmentation_class_frequency = {}
+for idx in tqdm(range(len(train_val_dataset))):
+    pc_viz = train_val_dataset[idx].pos.numpy().tolist()
+    segmentation_label = train_val_dataset[idx].y.numpy().tolist()
+    for label in set(segmentation_label):
+        segmentation_class_frequency[label] = segmentation_label.count(label)
+class_offset = min(list(segmentation_class_frequency.keys()))
+print("Class Offset:", class_offset)
+
+for idx in range(len(train_val_dataset)):
+    train_val_dataset[idx].y -= class_offset
+
+num_train_examples = int((1 - config.validation_split) * len(train_val_dataset))
+train_dataset = train_val_dataset[:num_train_examples]
+val_dataset = train_val_dataset[num_train_examples:]
+
+train_loader = DataLoader(
+    train_dataset, batch_size=config.batch_size,
+    shuffle=True, num_workers=config.num_workers
+)
+val_loader = DataLoader(
+    val_dataset, batch_size=config.batch_size,
+    shuffle=False, num_workers=config.num_workers
+)
+visualization_loader = DataLoader(
+    val_dataset[:10], batch_size=1,
+    shuffle=False, num_workers=config.num_workers
+)
+
+
+class DGCNN(torch.nn.Module):
+    def __init__(self, out_channels, k=30, aggr='max'):
+        super().__init__()
+        self.conv1 = DynamicEdgeConv(
+            MLP([2 * 6, 64, 64]), k, aggr
+        )
+        self.conv2 = DynamicEdgeConv(
+            MLP([2 * 64, 64, 64]), k, aggr
+        )
+        self.conv3 = DynamicEdgeConv(
+            MLP([2 * 64, 64, 64]), k, aggr
+        )
+        self.mlp = MLP(
+            [3 * 64, 1024, 256, 128, out_channels],
+            dropout=0.5, norm=None
+        )
+
+    def forward(self, data):
+        x, pos, batch = data.x, data.pos, data.batch
+        x0 = torch.cat([x, pos], dim=-1)
+        x1 = self.conv1(x0, batch)
+        x2 = self.conv2(x1, batch)
+        x3 = self.conv3(x2, batch)
+        out = self.mlp(torch.cat([x1, x2, x3], dim=1))
+        return F.log_softmax(out, dim=1)
+
+
+config.num_classes = train_dataset.num_classes
+
+model = DGCNN(
+    out_channels=train_dataset.num_classes,
+    k=config.num_nearest_neighbours,
+    aggr=config.aggregation_operator
+).to(device)
+
+# Define a new class that extends pl.LightningModule
+class MyModel(pl.LightningModule):
+
+    def __init__(self, config):
+        super(MyModel, self).__init__()
+
+        self.config = config
+        self.model = DGCNN(
+            out_channels=self.config.num_classes,
+            k=self.config.num_nearest_neighbours,
+            aggr=self.config.aggregation_operator
+        )
+
+    def forward(self, data):
+        return self.model(data)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.initial_lr)
+        scheduler = StepLR(optimizer, step_size=self.config.lr_scheduler_step_size, gamma=self.config.gamma)
+        return [optimizer], [scheduler]
+
+    def training_step(self, batch, batch_idx):
+        data = batch
+        outs = self(data)
+        loss = F.nll_loss(outs, data.y)
+        self.log('train_loss', loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        data = batch
+        outs = self(data)
+        loss = F.nll_loss(outs, data.y)
+        self.log('val_loss', loss)
+        return loss
+
+# ...
+
+wandb_logger = WandbLogger(name=wandb_run_name, project=wandb_project, entity="maciej-wielgosz-nibio")
+
+config.num_classes = train_dataset.num_classes
+model = MyModel(config)
+
+trainer = Trainer(max_epochs=config.epochs, gpus=1 if torch.cuda.is_available() else None, logger=wandb_logger)
+trainer.fit(model, train_loader, val_loader)
+
+wandb.finish()
--
GitLab
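
Note on the "parallel" part of the subject: the Trainer in this file is configured with gpus=1, so training still runs on a single device. Below is a minimal sketch of how the same Trainer call could be switched to data-parallel training across several GPUs with Lightning's DDP strategy. It is not part of the commit; it assumes a pytorch_lightning 1.x release where the accelerator/devices/strategy arguments coexist with the older gpus flag, and it reuses config, wandb_logger, model, train_loader and val_loader defined in the file added above.

    import torch
    from pytorch_lightning import Trainer

    n_gpus = torch.cuda.device_count()

    # Same settings as in the patch, but device selection is derived from the
    # number of visible GPUs instead of being hard-coded to a single one.
    trainer_kwargs = dict(
        max_epochs=config.epochs,
        accelerator="gpu" if n_gpus > 0 else "cpu",
        devices=max(n_gpus, 1),
        logger=wandb_logger,
    )
    if n_gpus > 1:
        # One training process per GPU; gradients are synchronised every step.
        trainer_kwargs["strategy"] = "ddp"

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model, train_loader, val_loader)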