From 0e069d09fc43d85683d9bda8c75f8acc43351be8 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Fri, 31 Mar 2023 12:34:58 +0200 Subject: [PATCH] dedicated model for classification --- dgcnn/dgcnn_train_pl.py | 16 ++-- dgcnn/find_missing_data.py | 20 ----- dgcnn/get_size_of_dataset.py | 42 +++++++++++ dgcnn/model.py | 2 + dgcnn/model_class.py | 137 +++++++++++++++++++++++++++++++++++ dgcnn/shapenet_data_dgcnn.py | 7 +- 6 files changed, 195 insertions(+), 29 deletions(-) delete mode 100644 dgcnn/find_missing_data.py create mode 100644 dgcnn/get_size_of_dataset.py create mode 100644 dgcnn/model_class.py diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index 65b4fb5..6db78aa 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -4,7 +4,7 @@ import yaml from shapenet_data_dgcnn import ShapenetDataDgcnn import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger -from model import DGCNN +from model_class import DgcnnClass from torchmetrics import Accuracy, Precision, Recall from pytorch_lightning.callbacks import Callback from pytorch_lightning.strategies import DDPStrategy @@ -41,7 +41,7 @@ with open('config.yaml', 'r') as f: class DGCNNLightning(pl.LightningModule): def __init__(self, num_classes): super().__init__() - self.dgcnn = DGCNN(num_classes) + self.dgcnn = DgcnnClass(num_classes) # train define metrics self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes) self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro') @@ -124,11 +124,11 @@ class DGCNNLightning(pl.LightningModule): self.test_recall.reset() def configure_optimizers(self): - # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr']) - # return optimizer - optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9) - scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001) - return {"optimizer": optimizer, "lr_scheduler": 
scheduler, "monitor": "train_loss"} + optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr']) + return optimizer + # optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9) + # scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001) + # return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"} # get train data @@ -152,7 +152,7 @@ shapenet_data_val = ShapenetDataDgcnn( small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], just_one_class=config['data']['just_one_class'], - split='val', + split='train', norm=config['data']['norm'] ) diff --git a/dgcnn/find_missing_data.py b/dgcnn/find_missing_data.py deleted file mode 100644 index d9bcc64..0000000 --- a/dgcnn/find_missing_data.py +++ /dev/null @@ -1,20 +0,0 @@ -from tqdm import tqdm -from shapenet_data_dgcnn import ShapenetDataDgcnn - -shapenet_data = ShapenetDataDgcnn( - root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', - npoints=128, - return_cls_label=True, - small_data=False, - small_data_size=1000, - just_one_class=False, - split='train', - norm=True - ) - -# read the data one by one and check if exists - -for i in tqdm(range(len(shapenet_data))): - data = shapenet_data[i] - if data[0].shape[0] != 128: - print(f"Data is None: {i}") \ No newline at end of file diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py new file mode 100644 index 0000000..a68f35b --- /dev/null +++ b/dgcnn/get_size_of_dataset.py @@ -0,0 +1,42 @@ +from tqdm import tqdm +from shapenet_data_dgcnn import ShapenetDataDgcnn + +shapenet_data_train = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=128, + return_cls_label=True, + small_data=False, + small_data_size=1000, + just_one_class=True, + split='train', + norm=True + ) + +shapenet_data_test = ShapenetDataDgcnn( + 
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=128, + return_cls_label=True, + small_data=False, + small_data_size=1000, + just_one_class=True, + split='test', + norm=True + ) + +shapenet_data_val = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=128, + return_cls_label=True, + small_data=False, + small_data_size=1000, + just_one_class=True, + split='val', + norm=True + ) + +# print the length of the data +print(f"Train: {len(shapenet_data_train)}") +print(f"Test: {len(shapenet_data_test)}") +print(f"Val: {len(shapenet_data_val)}") + + diff --git a/dgcnn/model.py b/dgcnn/model.py index b3dae82..89c2781 100644 --- a/dgcnn/model.py +++ b/dgcnn/model.py @@ -100,6 +100,7 @@ class DGCNN(nn.Module): self.bn5, nn.LeakyReLU(negative_slope=0.2)) self.linear1 = nn.Linear(2048, 512, bias=False) + self.dropout = nn.Dropout(p=0.5) self.fc = nn.Sequential( nn.Linear(512, 256), @@ -143,5 +144,6 @@ class DGCNN(nn.Module): x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2) # x9 = x9.max(dim=1, keepdim=False)[0] x10 = self.linear1(x8) + x10 = self.dropout(x10) x11 = self.fc(x10) return x11 \ No newline at end of file diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py new file mode 100644 index 0000000..10ccfce --- /dev/null +++ b/dgcnn/model_class.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166 +# TODO: There are more conv layers there + + +class EdgeConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(EdgeConv, self).__init__() + self.in_channels = in_channels + self.conv = nn.Sequential( + nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU() + ) + + def forward(self, x, k=20): + #batch_size, num_points, 
in_channels + + batch_size, num_points, feature_dim = x.shape + x = x.view(batch_size, num_points,feature_dim ) + knn_indices = self.knn(x, k) + knn_gathered = self.gather_neighbors(x, knn_indices) + edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1) + edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k) + return self.conv(edge_features).transpose(2, 1) + + @staticmethod + def knn(x, k): + """Find the indices of the k nearest neighbors for each point in the input tensor.""" + batch_size, num_points, _ = x.shape + x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1) + x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1) + distances = torch.norm(x_expanded - x_transposed, dim=-1) + _, indices = distances.topk(k=k, dim=-1, largest=False) + return indices + + def gather_neighbors(self, x, knn_indices): + batch_size, num_points, _ = x.shape + _, _, k = knn_indices.shape + x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1) + neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, self.in_channels)) + return neighbors + + +class Transform_Net(nn.Module): + def __init__(self, k=3): + super(Transform_Net, self).__init__() + self.k = k + + self.conv1 = nn.Conv1d(k, 64, 1) + self.conv2 = nn.Conv1d(64, 128, 1) + self.conv3 = nn.Conv1d(128, 256, 1) + + + self.fc3 = nn.Linear(256, k*k) + + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(128) + self.bn3 = nn.BatchNorm1d(256) + self.bn4 = nn.BatchNorm1d(128) + self.bn5 = nn.BatchNorm1d(64) + self.dropout_0 = nn.Dropout(p=0.5) + self.dropout_1 = nn.Dropout(p=0.5) + self.dropout_2 = nn.Dropout(p=0.5) + + def forward(self, x): + # Input shape: (batch_size, k, num_points) + x = x.transpose(2, 1) + x = F.relu(self.bn1(self.conv1(x))) + x = F.relu(self.bn2(self.conv2(x))) + x = F.relu(self.bn3(self.conv3(x))) + + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 256) + + x = self.fc3(x) + x 
= self.dropout_2(x) + + identity = torch.eye(self.k, dtype=x.dtype, device=x.device) + transform = x.view(-1, self.k, self.k) + identity + + transform = transform.transpose(2, 1) + + return transform + +class DgcnnClass(nn.Module): + def __init__(self, num_classes): + super(DgcnnClass, self).__init__() + self.transform_net = Transform_Net() + self.edge_conv1 = EdgeConv(3, 64) + self.edge_conv2 = EdgeConv(64, 128) + self.bn5 = nn.BatchNorm1d(256) + self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False), + self.bn5, + nn.LeakyReLU(negative_slope=0.2)) + self.linear1 = nn.Linear(512, 256, bias=False) + self.dropout = nn.Dropout(p=0.5) + + self.fc = nn.Sequential( + nn.Linear(256, 128), + nn.BatchNorm1d(128), + nn.ReLU(), + nn.Dropout(p=0.5), + nn.Linear(128, num_classes), + ) + + def forward(self, x): + # Apply Transform_Net on input point cloud + batch_size = x.size(0) + + trans_matrix = self.transform_net(x) + x = torch.bmm(x, trans_matrix) + x1 = self.edge_conv1(x) + x1 = x1.max(dim=-1, keepdim=False)[0] + # print("x1 shape: ", x1.shape) + x2 = self.edge_conv2(x1) + x2 = x2.max(dim=-1, keepdim=False)[0] + # print("x2 shape: ", x2.shape) + x5 = torch.cat((x1, x2), dim=2) # (batch_size, 64+64+128+256, num_points) + x5 = x5.transpose(2, 1) # (batch_size, num_points, 64+64+128+256) + # print("x5 shape: ", x5.shape) + x_conv = self.conv5(x5) # (batch_size, 1024, num_points) + # print("x_conv shape: ", x_conv.shape) + x6 = F.adaptive_max_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) + # print("x6 shape: ", x6.shape) + x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) + # print("x7 shape: ", x7.shape) + x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2) + # x9 = x8.max(dim=1, keepdim=False)[0] + x10 = self.linear1(x8) + x10 = self.dropout(x10) + x11 = self.fc(x10) + return x11 \ No newline at end of file diff --git 
a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 8196648..d6f6b02 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -90,7 +90,12 @@ class ShapenetDataDgcnn(object): # get the the number of the class airplane if self.just_one_class: - data = [x for x in data if x.split('/')[-2] == self.cat['Airplane']] + data = [x for x in data if x.split('/')[-2] in [ + self.cat['Airplane'], + self.cat['Lamp'], + self.cat['Chair'], + self.cat['Table'], + ]] return data -- GitLab