From 0e069d09fc43d85683d9bda8c75f8acc43351be8 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Fri, 31 Mar 2023 12:34:58 +0200
Subject: [PATCH] dedicated model for classification

---
 dgcnn/dgcnn_train_pl.py      |  16 ++--
 dgcnn/find_missing_data.py   |  20 -----
 dgcnn/get_size_of_dataset.py |  42 +++++++++++
 dgcnn/model.py               |   2 +
 dgcnn/model_class.py         | 137 +++++++++++++++++++++++++++++++++++
 dgcnn/shapenet_data_dgcnn.py |   7 +-
 6 files changed, 195 insertions(+), 29 deletions(-)
 delete mode 100644 dgcnn/find_missing_data.py
 create mode 100644 dgcnn/get_size_of_dataset.py
 create mode 100644 dgcnn/model_class.py

diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 65b4fb5..6db78aa 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -4,7 +4,7 @@ import yaml
 from shapenet_data_dgcnn import ShapenetDataDgcnn
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
-from model import DGCNN
+from model_class import DgcnnClass
 from torchmetrics import Accuracy, Precision, Recall
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
@@ -41,7 +41,7 @@ with open('config.yaml', 'r') as f:
 class DGCNNLightning(pl.LightningModule):
     def __init__(self, num_classes):
         super().__init__()
-        self.dgcnn = DGCNN(num_classes)
+        self.dgcnn = DgcnnClass(num_classes)
         # train define metrics
         self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
         self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
@@ -124,11 +124,11 @@ class DGCNNLightning(pl.LightningModule):
         self.test_recall.reset()
     
     def configure_optimizers(self):
-        # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
-        # return optimizer
-        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
-        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
-        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
+        optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
+        return optimizer
+        # optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
+        # scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
+        # return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
 
 
 # get train data 
@@ -152,7 +152,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
         just_one_class=config['data']['just_one_class'],
-        split='val',
+        split='train',
         norm=config['data']['norm']
         )
 
diff --git a/dgcnn/find_missing_data.py b/dgcnn/find_missing_data.py
deleted file mode 100644
index d9bcc64..0000000
--- a/dgcnn/find_missing_data.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from tqdm import tqdm
-from shapenet_data_dgcnn import ShapenetDataDgcnn
-
-shapenet_data = ShapenetDataDgcnn(
-      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
-      npoints=128,
-      return_cls_label=True,
-      small_data=False,
-      small_data_size=1000,
-      just_one_class=False,
-      split='train',
-      norm=True
-      )
-
-# read the data one by one and check if exists
-
-for i in tqdm(range(len(shapenet_data))):
-    data = shapenet_data[i]
-    if data[0].shape[0] != 128:
-        print(f"Data is None: {i}")
\ No newline at end of file
diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py
new file mode 100644
index 0000000..a68f35b
--- /dev/null
+++ b/dgcnn/get_size_of_dataset.py
@@ -0,0 +1,42 @@
+from tqdm import tqdm
+from shapenet_data_dgcnn import ShapenetDataDgcnn
+
+shapenet_data_train = ShapenetDataDgcnn(
+      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+      npoints=128,
+      return_cls_label=True,
+      small_data=False,
+      small_data_size=1000,
+      just_one_class=True,
+      split='train',
+      norm=True
+      )
+
+shapenet_data_test = ShapenetDataDgcnn(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+        npoints=128,
+        return_cls_label=True,
+        small_data=False,
+        small_data_size=1000,
+        just_one_class=True,
+        split='test',
+        norm=True
+        )
+
+shapenet_data_val = ShapenetDataDgcnn(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+        npoints=128,
+        return_cls_label=True,
+        small_data=False,
+        small_data_size=1000,
+        just_one_class=True,
+        split='val',
+        norm=True
+        )
+
+# print the length of the data
+print(f"Train: {len(shapenet_data_train)}")
+print(f"Test: {len(shapenet_data_test)}")
+print(f"Val: {len(shapenet_data_val)}")
+
+
diff --git a/dgcnn/model.py b/dgcnn/model.py
index b3dae82..89c2781 100644
--- a/dgcnn/model.py
+++ b/dgcnn/model.py
@@ -100,6 +100,7 @@ class DGCNN(nn.Module):
                                    self.bn5,
                                    nn.LeakyReLU(negative_slope=0.2))
         self.linear1 = nn.Linear(2048, 512, bias=False)
+        self.dropout = nn.Dropout(p=0.5)
 
         self.fc = nn.Sequential(
             nn.Linear(512, 256),
@@ -143,5 +144,6 @@ class DGCNN(nn.Module):
         x8 = torch.cat((x6, x7), 1)              # (batch_size, emb_dims*2)
         # x9 = x9.max(dim=1, keepdim=False)[0]
         x10 = self.linear1(x8)
+        x10 = self.dropout(x10)
         x11 = self.fc(x10)
         return x11
\ No newline at end of file
diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py
new file mode 100644
index 0000000..10ccfce
--- /dev/null
+++ b/dgcnn/model_class.py
@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+# TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
+# TODO: There are mode conv layers there
+
+
+class EdgeConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConv, self).__init__()
+        self.in_channels = in_channels
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        )
+
+    def forward(self, x, k=20):
+        #batch_size, num_points, in_channels
+
+        batch_size, num_points, feature_dim = x.shape
+        x = x.view(batch_size, num_points,feature_dim )
+        knn_indices = self.knn(x, k)
+        knn_gathered = self.gather_neighbors(x, knn_indices)
+        edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)
+        edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k)
+        return self.conv(edge_features).transpose(2, 1)
+
+    @staticmethod
+    def knn(x, k):
+        """Find the indices of the k nearest neighbors for each point in the input tensor."""
+        batch_size, num_points, _ = x.shape
+        x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)
+        x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)
+        distances = torch.norm(x_expanded - x_transposed, dim=-1)
+        _, indices = distances.topk(k=k, dim=-1, largest=False)
+        return indices
+    
+    def gather_neighbors(self, x, knn_indices):
+        batch_size, num_points, _ = x.shape
+        _, _, k = knn_indices.shape
+        x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)
+        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, self.in_channels))
+        return neighbors
+    
+    
+class Transform_Net(nn.Module):
+    def __init__(self, k=3):
+        super(Transform_Net, self).__init__()
+        self.k = k
+
+        self.conv1 = nn.Conv1d(k, 64, 1)
+        self.conv2 = nn.Conv1d(64, 128, 1)
+        self.conv3 = nn.Conv1d(128, 256, 1)
+
+   
+        self.fc3 = nn.Linear(256, k*k)
+
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(256)
+        self.bn4 = nn.BatchNorm1d(128)
+        self.bn5 = nn.BatchNorm1d(64)
+        self.dropout_0 = nn.Dropout(p=0.5)
+        self.dropout_1 = nn.Dropout(p=0.5)
+        self.dropout_2 = nn.Dropout(p=0.5)
+
+    def forward(self, x):
+        # Input shape: (batch_size, k, num_points)
+        x = x.transpose(2, 1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 256)
+
+        x = self.fc3(x)
+        x = self.dropout_2(x)
+
+        identity = torch.eye(self.k, dtype=x.dtype, device=x.device)
+        transform = x.view(-1, self.k, self.k) + identity
+
+        transform = transform.transpose(2, 1)
+
+        return transform
+
+class DgcnnClass(nn.Module):
+    def __init__(self, num_classes):
+        super(DgcnnClass, self).__init__()
+        self.transform_net = Transform_Net()
+        self.edge_conv1 = EdgeConv(3, 64)
+        self.edge_conv2 = EdgeConv(64, 128)
+        self.bn5 = nn.BatchNorm1d(256)
+        self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
+                                   self.bn5,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.linear1 = nn.Linear(512, 256, bias=False)
+        self.dropout = nn.Dropout(p=0.5)
+
+        self.fc = nn.Sequential(
+            nn.Linear(256, 128),
+            nn.BatchNorm1d(128),
+            nn.ReLU(),
+            nn.Dropout(p=0.5),
+            nn.Linear(128, num_classes),
+        )
+
+    def forward(self, x):
+        # Apply Transform_Net on input point cloud
+        batch_size = x.size(0)
+
+        trans_matrix = self.transform_net(x)
+        x = torch.bmm(x, trans_matrix)
+        x1 = self.edge_conv1(x)
+        x1 = x1.max(dim=-1, keepdim=False)[0] 
+        # print("x1 shape: ", x1.shape)
+        x2 = self.edge_conv2(x1)
+        x2 = x2.max(dim=-1, keepdim=False)[0]
+        # print("x2 shape: ", x2.shape)
+        x5 = torch.cat((x1, x2), dim=2)  # (batch_size, 64+64+128+256, num_points)
+        x5 = x5.transpose(2, 1)                 # (batch_size, num_points, 64+64+128+256)
+        # print("x5 shape: ", x5.shape)
+        x_conv = self.conv5(x5)                     # (batch_size, 1024, num_points)   
+        # print("x_conv shape: ", x_conv.shape)
+        x6 = F.adaptive_max_pool1d(x_conv, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # print("x6 shape: ", x6.shape)
+        x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # print("x7 shape: ", x7.shape)
+        x8 = torch.cat((x6, x7), 1)              # (batch_size, emb_dims*2)
+        # x9 = x8.max(dim=1, keepdim=False)[0]
+        x10 = self.linear1(x8)
+        x10 = self.dropout(x10)
+        x11 = self.fc(x10)
+        return x11
\ No newline at end of file
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index 8196648..d6f6b02 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -90,7 +90,12 @@ class ShapenetDataDgcnn(object):
         # get the the number of the class airplane
 
         if self.just_one_class:
-            data = [x for x in data if x.split('/')[-2] == self.cat['Airplane']]
+            data = [x for x in data if x.split('/')[-2] in [
+                self.cat['Airplane'],
+                self.cat['Lamp'],
+                self.cat['Chair'],
+                self.cat['Table'],
+                ]]
 
         return data
     
-- 
GitLab