Skip to content
Snippets Groups Projects
Commit 41e3da5f authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

parallel implemenation pth lighting

parent 86333368
No related branches found
No related tags found
No related merge requests found
import os
import wandb
import random
import numpy as np
from tqdm.auto import tqdm
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import StepLR
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torchmetrics.functional import jaccard_index
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv
from my_data_loader import MyData
wandb_project = "pyg-point-cloud" #@param {"type": "string"} , maciej-wielgosz-nibio
wandb_run_name = "train-dgcnn" #@param {"type": "string"}
wandb.init(
entity="maciej-wielgosz-nibio",
project=wandb_project,
name=wandb_run_name,
job_type="train"
)
config = wandb.config
config.seed = 42
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
random.seed(config.seed)
torch.manual_seed(config.seed)
device = torch.device(config.device)
config.category = 'Car' #@param ["Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"}
config.random_jitter_translation = 1e-2
config.random_rotation_interval_x = 15
config.random_rotation_interval_y = 15
config.random_rotation_interval_z = 15
config.validation_split = 0.2
config.batch_size = 4
config.num_workers = 6
config.num_nearest_neighbours = 30
config.aggregation_operator = "max"
config.dropout = 0.5
config.initial_lr = 1e-3
config.lr_scheduler_step_size = 5
config.gamma = 0.8
config.epochs = 1
transform = T.Compose([
T.RandomJitter(config.random_jitter_translation),
T.RandomRotate(config.random_rotation_interval_x, axis=0),
T.RandomRotate(config.random_rotation_interval_y, axis=1),
T.RandomRotate(config.random_rotation_interval_z, axis=2)
])
pre_transform = T.NormalizeScale()
# dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/maciek_data/plane_maciek"
dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/nibio_data_no_commas"
train_val_dataset = MyData(
dataset_path, config.category, split='trainval',
transform=transform, pre_transform=pre_transform
)
segmentation_class_frequency = {}
for idx in tqdm(range(len(train_val_dataset))):
pc_viz = train_val_dataset[idx].pos.numpy().tolist()
segmentation_label = train_val_dataset[idx].y.numpy().tolist()
for label in set(segmentation_label):
segmentation_class_frequency[label] = segmentation_label.count(label)
class_offset = min(list(segmentation_class_frequency.keys()))
print("Class Offset:", class_offset)
for idx in range(len(train_val_dataset)):
train_val_dataset[idx].y -= class_offset
num_train_examples = int((1 - config.validation_split) * len(train_val_dataset))
train_dataset = train_val_dataset[:num_train_examples]
val_dataset = train_val_dataset[num_train_examples:]
train_loader = DataLoader(
train_dataset, batch_size=config.batch_size,
shuffle=True, num_workers=config.num_workers
)
val_loader = DataLoader(
val_dataset, batch_size=config.batch_size,
shuffle=False, num_workers=config.num_workers
)
visualization_loader = DataLoader(
val_dataset[:10], batch_size=1,
shuffle=False, num_workers=config.num_workers
)
class DGCNN(torch.nn.Module):
def __init__(self, out_channels, k=30, aggr='max'):
super().__init__()
self.conv1 = DynamicEdgeConv(
MLP([2 * 6, 64, 64]), k, aggr
)
self.conv2 = DynamicEdgeConv(
MLP([2 * 64, 64, 64]), k, aggr
)
self.conv3 = DynamicEdgeConv(
MLP([2 * 64, 64, 64]), k, aggr
)
self.mlp = MLP(
[3 * 64, 1024, 256, 128, out_channels],
dropout=0.5, norm=None
)
def forward(self, data):
x, pos, batch = data.x, data.pos, data.batch
x0 = torch.cat([x, pos], dim=-1)
x1 = self.conv1(x0, batch)
x2 = self.conv2(x1, batch)
x3 = self.conv3(x2, batch)
out = self.mlp(torch.cat([x1, x2, x3], dim=1))
return F.log_softmax(out, dim=1)
config.num_classes = train_dataset.num_classes
model = DGCNN(
out_channels=train_dataset.num_classes,
k=config.num_nearest_neighbours,
aggr=config.aggregation_operator
).to(device)
# Define a new class that extends pl.LightningModule
class MyModel(pl.LightningModule):
def __init__(self, config):
super(MyModel, self).__init__()
self.config = config
self.model = DGCNN(
out_channels=self.config.num_classes,
k=self.config.num_nearest_neighbours,
aggr=self.config.aggregation_operator
)
def forward(self, data):
return self.model(data)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.config.initial_lr)
scheduler = StepLR(optimizer, step_size=self.config.lr_scheduler_step_size, gamma=self.config.gamma)
return [optimizer], [scheduler]
def training_step(self, batch, batch_idx):
data = batch
outs = self(data)
loss = F.nll_loss(outs, data.y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
data = batch
outs = self(data)
loss = F.nll_loss(outs, data.y)
self.log('val_loss', loss)
return loss
# ...
wandb_logger = WandbLogger(name=wandb_run_name, project=wandb_project, entity="maciej-wielgosz-nibio")
config.num_classes = train_dataset.num_classes
model = MyModel(config)
trainer = Trainer(max_epochs=config.epochs, gpus=1 if torch.cuda.is_available() else None, logger=wandb_logger)
trainer.fit(model, train_loader, val_loader)
wandb.finish()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment