diff --git a/dgcnn/dgcnn_shape_net_inference.py b/dgcnn/dgcnn_shape_net_inference.py index 079a7c6750816dabccc4ede3943b92b3d144c196..36dcba693c6aae1a1bd0e88f22ed15611b101fdb 100644 --- a/dgcnn/dgcnn_shape_net_inference.py +++ b/dgcnn/dgcnn_shape_net_inference.py @@ -5,14 +5,12 @@ import torch.nn.init as init from my_models.model_shape_net import DgcnShapeNet as DGCNN - - def main(): - dgcnn = DGCNN(50) + dgcnn = DGCNN(50, 4) dgcnn.eval() # simple test input_tensor = torch.randn(2, 128, 3) - label_one_hot = torch.zeros((2, 16)) + label_one_hot = torch.zeros((2, 4)) print(input_tensor.shape) out = dgcnn(input_tensor, label_one_hot) print(out.shape) diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index e49a9433c0b5ffcb95667ff53e70b217c4bff044..dd1366c46b2403dd38471a2e6de054ee5b580af1 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='train', norm=config['data']['norm'], augmnetation=config['data']['augmentation'] @@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='test', norm=config['data']['norm'] ) @@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - num_classes=config['data']['just_four_classes'], + num_classes=config['data']['num_classes'], split='test', norm=config['data']['norm'] ) @@ -217,10 +217,8 @@ else: ) # Initialize a model -if config['data']['just_four_classes']: - num_classes = 4 -else: - num_classes = 16 +num_classes = int(config['data']['num_classes']) + model = DGCNNLightning(num_classes=num_classes) if config['wandb']['use_wandb']: diff --git a/dgcnn/dgcnn_train_shape_net.py b/dgcnn/dgcnn_train_shape_net.py index 9d9b9b96e626f302ae832e399314fc0f723311f4..9aeb3884f3ae603004c5fe9d46580924a2fcb013 100644 --- a/dgcnn/dgcnn_train_shape_net.py +++ b/dgcnn/dgcnn_train_shape_net.py @@ -31,12 +31,12 @@ def train(): # get data shapenet_data = ShapenetDataDgcnn( root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', - npoints=1024, + npoints=32, return_cls_label=True, - small_data=False, - small_data_size=300, - just_four_classes=False, - data_augmentation=True, + small_data=True, + small_data_size=10, + num_classes=True, + data_augmentation=False, split='train', norm=True ) @@ -44,7 +44,7 @@ def train(): # create a dataloader dataloader = torch.utils.data.DataLoader( shapenet_data, - batch_size=4, + batch_size=8, shuffle=True, num_workers=8, drop_last=True @@ -77,6 +77,11 @@ def train(): # print(f"Batch: {i}") points, labels, class_name = data + # log data to wandb + if use_wandb: + wandb.log({"points": wandb.Object3D(points[0, :, :].cpu().numpy())}) + # wandb.log({"labels": labels.cpu().numpy()}) + # wandb.log({"class_name": class_name.cpu().numpy()}) label_one_hot = np.zeros((class_name.shape[0], 16)) for idx in range(class_name.shape[0]): label_one_hot[idx, class_name[idx]] = 1 diff --git a/dgcnn/dgcnn_train_shape_net_cat_1-4.py b/dgcnn/dgcnn_train_shape_net_cat_1-4.py new file mode 100644 index 0000000000000000000000000000000000000000..3c25b49d3e93d7aa2c111b4e54b6358317e88699 --- /dev/null +++ b/dgcnn/dgcnn_train_shape_net_cat_1-4.py @@ -0,0 +1,122 @@ +import random +import numpy as np +import torch +import os +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from matplotlib import pyplot as plt +from shapenet_data_dgcnn import ShapenetDataDgcnn +# import IoU from torchmetrics +from torchmetrics import Accuracy +from torchmetrics import JaccardIndex as IoU +from torch.optim.lr_scheduler import StepLR + +import wandb + +use_wandb = False +num_classes = 4 +batch_size = 4 + +# create a wandb run +if use_wandb: + run = wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio") + + +from my_models.model_shape_net import DgcnShapeNet as DGCNN + + +def train(): + seg_num_all = 50 + dgcnn = DGCNN( + seg_num_all=seg_num_all, + num_classes=num_classes + ).cuda() + dgcnn.train() + + # get data + shapenet_data = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=32, + return_cls_label=True, + small_data=True, + small_data_size=10, + num_classes=num_classes, + data_augmentation=False, + split='train', + norm=True + ) + + # create a dataloader + dataloader = torch.utils.data.DataLoader( + shapenet_data, + batch_size=batch_size, + shuffle=True, + num_workers=8, + drop_last=True + ) + + # create a optimizer + optimizer = torch.optim.Adam(dgcnn.parameters(), lr=0.001, weight_decay=1e-4) + scheduler = StepLR(optimizer, step_size=20, gamma=0.5) + if use_wandb: + # create a config wandb + wandb.config.update({ + "batch_size": batch_size, + "learning_rate": 0.01, + "optimizer": "Adam", + "loss_function": "cross_entropy" + }) + + # train + iou = IoU(num_classes=50, task='multiclass', average='macro').cuda() + acc = Accuracy(num_classes=50, compute_on_step=False, dist_sync_on_step=False, task='multiclass').cuda() + + for epoch in range(500): + iou.reset() + print(f"Epoch: {epoch}") + if use_wandb: + wandb.log({"epoch": epoch}) + for i, data in enumerate(dataloader, 0): + # print(f"Batch: {i}") + + points, labels, class_name = data + print('class_name', class_name) + label_one_hot = np.zeros((class_name.shape[0], num_classes)) + for idx in range(class_name.shape[0]): + + label_one_hot[idx, class_name[idx]] = 1 + label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) + + points = points.cuda() + labels = labels.cuda() + label_one_hot = label_one_hot.cuda() + + optimizer.zero_grad() + pred = dgcnn(points, label_one_hot) + + loss = F.cross_entropy( + pred, labels, reduction='mean', ignore_index=255) + loss.backward() + optimizer.step() + if optimizer.param_groups[0]['lr'] > 1e-5: + scheduler.step() + + # print lose every 10 batches + if i % 100 == 0: + print(loss.item()) + pred_softmax = F.softmax(pred, dim=1) + pred_argmax = torch.argmax(pred_softmax, dim=1) + gt_one_hot = F.one_hot(labels, num_classes=50) + pred_one_hot = F.one_hot(pred_argmax, num_classes=50) + print("loss : ", loss.item()) + print("IoU : ", iou(pred_one_hot, gt_one_hot)) + print("Acc : ", acc(pred_argmax, labels)) + if use_wandb: + wandb.log({"loss": loss.item()}) + wandb.log({"iou": iou(pred, labels)}) + wandb.log({"acc": acc(pred, labels)}) + + +if __name__ == '__main__': + train() \ No newline at end of file diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py index fc275ac8734d6df2cc9547bc14d46b4fb0b5b8eb..244f4eaf70db12d8a5ee9a9d4b5e55785e62ff37 100644 --- a/dgcnn/my_models/model_shape_net.py +++ b/dgcnn/my_models/model_shape_net.py @@ -43,7 +43,7 @@ class DgcnShapeNet(nn.Module): - def forward(self, x, class_label): + def forward(self, x, class_label_one_hot): # Apply Transform_Net on input point cloud trans_matrix = self.transform_net(x) @@ -63,10 +63,10 @@ class DgcnShapeNet(nn.Module): x = self.conv5(x5) # (batch_size, 1024, num_points) x = x.max(dim=-1, keepdim=True)[0] # (batch_size, 1024) - class_label = class_label.view(batch_size, -1, 1) # (batch_size, num_categoties, 1) - class_label = self.conv7(class_label) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1) + class_label_one_hot = class_label_one_hot.view(batch_size, -1, 1) # (batch_size, num_categoties, 1) + class_label_one_hot = self.conv7(class_label_one_hot) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1) - x = torch.cat((x, class_label), dim=1) # (batch_size, 1088, 1) + x = torch.cat((x, class_label_one_hot), dim=1) # (batch_size, 1088, 1) x = x.repeat(1, 1, num_points) # (batch_size, 1088, num_points) x = torch.cat((x, x1, x2), dim=1) # (batch_size, 1088+64*3, num_points) diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index d518a639e1fb2a97cd4ed6379664b5ff2820188f..71fb5f89c8282419bc70f5c1562c32417897a0a5 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object): small_data=False, small_data_size=10, return_cls_label=False, - num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4 + num_classes=1, # you can choose 1, 2, 3, 4, or 16 norm=False, augmnetation=False, data_augmentation=False @@ -241,12 +241,15 @@ class ShapenetDataDgcnn(object): class_name = self.val_data_file[index].split('/')[-2] # apply the mapper - # if self.num_classes: - # class_name = self.class_mapper_4_classes(class_name) - # else: - # class_name = self.class_mapper(class_name) + if self.num_classes in range(1, 5): + class_name = self.class_mapper_4_classes(class_name) + elif self.num_classes == 16: + class_name = self.class_mapper(class_name) + else: + raise ValueError('num_classes not in range, should be in range 1-4 or 16') + - class_name = self.class_mapper(class_name) + # class_name = self.class_mapper(class_name) # convert the class name to a number class_name = np.array(class_name, dtype=np.int64) @@ -254,6 +257,7 @@ class ShapenetDataDgcnn(object): # map to tensor # class_name = torch.from_numpy(class_name) + if self.return_cls_label: return point_set, labels, class_name else: