From f1468350f2322b3c510f2ddba96f8b28f825dc74 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Mon, 24 Apr 2023 11:08:40 +0200 Subject: [PATCH] updated model for 1 - 4 categories shapenet segmentation --- dgcnn/dgcnn_shape_net_inference.py | 6 +- dgcnn/dgcnn_train_pl.py | 12 +-- dgcnn/dgcnn_train_shape_net.py | 17 ++-- dgcnn/dgcnn_train_shape_net_cat_1-4.py | 122 +++++++++++++++++++++++++ dgcnn/my_models/model_shape_net.py | 8 +- dgcnn/shapenet_data_dgcnn.py | 16 ++-- 6 files changed, 154 insertions(+), 27 deletions(-) create mode 100644 dgcnn/dgcnn_train_shape_net_cat_1-4.py diff --git a/dgcnn/dgcnn_shape_net_inference.py b/dgcnn/dgcnn_shape_net_inference.py index 079a7c6..36dcba6 100644 --- a/dgcnn/dgcnn_shape_net_inference.py +++ b/dgcnn/dgcnn_shape_net_inference.py @@ -5,14 +5,12 @@ import torch.nn.init as init from my_models.model_shape_net import DgcnShapeNet as DGCNN - - def main(): - dgcnn = DGCNN(50) + dgcnn = DGCNN(50, 4) dgcnn.eval() # simple test input_tensor = torch.randn(2, 128, 3) - label_one_hot = torch.zeros((2, 16)) + label_one_hot = torch.zeros((2, 4)) print(input_tensor.shape) out = dgcnn(input_tensor, label_one_hot) print(out.shape) diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index e49a943..dd1366c 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='train', norm=config['data']['norm'], augmnetation=config['data']['augmentation'] @@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='test', norm=config['data']['norm'] ) @@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='test', norm=config['data']['norm'] ) @@ -217,10 +217,8 @@ else: ) # Initialize a model -if config['data']['just_four_classes']: - num_classes = 4 -else: - num_classes = 16 +num_classes = int(config['data']['num_classes']) + model = DGCNNLightning(num_classes=num_classes) if config['wandb']['use_wandb']: diff --git a/dgcnn/dgcnn_train_shape_net.py b/dgcnn/dgcnn_train_shape_net.py index 9d9b9b9..9aeb388 100644 --- a/dgcnn/dgcnn_train_shape_net.py +++ b/dgcnn/dgcnn_train_shape_net.py @@ -31,12 +31,12 @@ def train(): # get data shapenet_data = ShapenetDataDgcnn( root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', - npoints=1024, + npoints=32, return_cls_label=True, - small_data=False, - small_data_size=300, - just_four_classes=False, - data_augmentation=True, + small_data=True, + small_data_size=10, + num_classes=True, + data_augmentation=False, split='train', norm=True ) @@ -44,7 +44,7 @@ def train(): # create a dataloader dataloader = torch.utils.data.DataLoader( shapenet_data, - batch_size=4, + batch_size=8, shuffle=True, num_workers=8, drop_last=True @@ -77,6 +77,11 @@ def train(): # print(f"Batch: {i}") points, labels, class_name = data + # log data to wandb + if use_wandb: + wandb.log({"points": wandb.Object3D(points[0, :, :].cpu().numpy())}) + # wandb.log({"labels": labels.cpu().numpy()}) + # wandb.log({"class_name": class_name.cpu().numpy()}) label_one_hot = np.zeros((class_name.shape[0], 16)) for idx in range(class_name.shape[0]): label_one_hot[idx, class_name[idx]] = 1 diff --git a/dgcnn/dgcnn_train_shape_net_cat_1-4.py b/dgcnn/dgcnn_train_shape_net_cat_1-4.py new file mode 100644 index 0000000..3c25b49 --- /dev/null +++ b/dgcnn/dgcnn_train_shape_net_cat_1-4.py @@ -0,0 +1,122 @@ +import random +import numpy as np +import torch +import os +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from matplotlib import pyplot as plt +from shapenet_data_dgcnn import ShapenetDataDgcnn +# import IoU from torchmetrics +from torchmetrics import Accuracy +from torchmetrics import JaccardIndex as IoU +from torch.optim.lr_scheduler import StepLR + +import wandb + +use_wandb = False +num_classes = 4 +batch_size = 4 + +# create a wandb run +if use_wandb: + run = wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio") + + +from my_models.model_shape_net import DgcnShapeNet as DGCNN + + +def train(): + seg_num_all = 50 + dgcnn = DGCNN( + seg_num_all=seg_num_all, + num_classes=num_classes + ).cuda() + dgcnn.train() + + # get data + shapenet_data = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=32, + return_cls_label=True, + small_data=True, + small_data_size=10, + num_classes=num_classes, + data_augmentation=False, + split='train', + norm=True + ) + + # create a dataloader + dataloader = torch.utils.data.DataLoader( + shapenet_data, + batch_size=batch_size, + shuffle=True, + num_workers=8, + drop_last=True + ) + + # create a optimizer + optimizer = torch.optim.Adam(dgcnn.parameters(), lr=0.001, weight_decay=1e-4) + scheduler = StepLR(optimizer, step_size=20, gamma=0.5) + if use_wandb: + # create a config wandb + wandb.config.update({ + "batch_size": batch_size, + "learning_rate": 0.01, + "optimizer": "Adam", + "loss_function": "cross_entropy" + }) + + # train + iou = IoU(num_classes=50, task='multiclass', average='macro').cuda() + acc = Accuracy(num_classes=50, compute_on_step=False, dist_sync_on_step=False, task='multiclass').cuda() + + for epoch in range(500): + iou.reset() + print(f"Epoch: {epoch}") + if use_wandb: + wandb.log({"epoch": epoch}) + for i, data in enumerate(dataloader, 0): + # print(f"Batch: {i}") + + points, labels, class_name = data + print('class_name', class_name) + label_one_hot = np.zeros((class_name.shape[0], num_classes)) + for idx in range(class_name.shape[0]): + + label_one_hot[idx, class_name[idx]] = 1 + label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) + + points = points.cuda() + labels = labels.cuda() + label_one_hot = label_one_hot.cuda() + + optimizer.zero_grad() + pred = dgcnn(points, label_one_hot) + + loss = F.cross_entropy( + pred, labels, reduction='mean', ignore_index=255) + loss.backward() + optimizer.step() + if optimizer.param_groups[0]['lr'] > 1e-5: + scheduler.step() + + # print lose every 10 batches + if i % 100 == 0: + print(loss.item()) + pred_softmax = F.softmax(pred, dim=1) + pred_argmax = torch.argmax(pred_softmax, dim=1) + gt_one_hot = F.one_hot(labels, num_classes=50) + pred_one_hot = F.one_hot(pred_argmax, num_classes=50) + print("loss : ", loss.item()) + print("IoU : ", iou(pred_one_hot, gt_one_hot)) + print("Acc : ", acc(pred_argmax, labels)) + if use_wandb: + wandb.log({"loss": loss.item()}) + wandb.log({"iou": iou(pred, labels)}) + wandb.log({"acc": acc(pred, labels)}) + + +if __name__ == '__main__': + train() \ No newline at end of file diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py index fc275ac..244f4ea 100644 --- a/dgcnn/my_models/model_shape_net.py +++ b/dgcnn/my_models/model_shape_net.py @@ -43,7 +43,7 @@ class DgcnShapeNet(nn.Module): - def forward(self, x, class_label): + def forward(self, x, class_label_one_hot): # Apply Transform_Net on input point cloud trans_matrix = self.transform_net(x) @@ -63,10 +63,10 @@ class DgcnShapeNet(nn.Module): x = self.conv5(x5) # (batch_size, 1024, num_points) x = x.max(dim=-1, keepdim=True)[0] # (batch_size, 1024) - class_label = class_label.view(batch_size, -1, 1) # (batch_size, num_categoties, 1) - class_label = self.conv7(class_label) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1) + class_label_one_hot = class_label_one_hot.view(batch_size, -1, 1) # (batch_size, num_categoties, 1) + class_label_one_hot = self.conv7(class_label_one_hot) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1) - x = torch.cat((x, class_label), dim=1) # (batch_size, 1088, 1) + x = torch.cat((x, class_label_one_hot), dim=1) # (batch_size, 1088, 1) x = x.repeat(1, 1, num_points) # (batch_size, 1088, num_points) x = torch.cat((x, x1, x2), dim=1) # (batch_size, 1088+64*3, num_points) diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index d518a63..71fb5f8 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object): small_data=False, small_data_size=10, return_cls_label=False, - num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4 + num_classes=1, # you can choose 1, 2, 3, 4, or 16 norm=False, augmnetation=False, data_augmentation=False @@ -241,12 +241,15 @@ class ShapenetDataDgcnn(object): class_name = self.val_data_file[index].split('/')[-2] # apply the mapper - # if self.num_classes: - # class_name = self.class_mapper_4_classes(class_name) - # else: - # class_name = self.class_mapper(class_name) + if self.num_classes in range(1, 5): + class_name = self.class_mapper_4_classes(class_name) + elif self.num_classes == 16: + class_name = self.class_mapper(class_name) + else: + raise ValueError('num_classes not in range, should be in range 1-4 or 16') + - class_name = self.class_mapper(class_name) + # class_name = self.class_mapper(class_name) # convert the class name to a number class_name = np.array(class_name, dtype=np.int64) @@ -254,6 +257,7 @@ class ShapenetDataDgcnn(object): # map to tensor # class_name = torch.from_numpy(class_name) + if self.return_cls_label: return point_set, labels, class_name else: -- GitLab