From e9981ec339f0158f8a3cf2f42b3497b4f6e30c57 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 23 Mar 2023 13:19:05 +0100
Subject: [PATCH] dgcnn works for classification of shapenet

---
 dgcnn/dataset_run.ipynb                       | 104 ++++++++++++++++++
 dgcnn/dgcnn_main.py                           |   3 +-
 dgcnn/dgcnn_train.py                          |  80 ++++++++++++++
 dgcnn/model.py                                |   6 +-
 ...hapenet_data.py => shapenet_data_dgcnn.py} |  75 ++++++++++++-
 5 files changed, 259 insertions(+), 9 deletions(-)
 create mode 100644 dgcnn/dataset_run.ipynb
 create mode 100644 dgcnn/dgcnn_train.py
 rename dgcnn/{shapenet_data.py => shapenet_data_dgcnn.py} (70%)

diff --git a/dgcnn/dataset_run.ipynb b/dgcnn/dataset_run.ipynb
new file mode 100644
index 0000000..5265d6c
--- /dev/null
+++ b/dgcnn/dataset_run.ipynb
@@ -0,0 +1,104 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'Airplane': '02691156', 'Bag': '02773838', 'Cap': '02954340', 'Car': '02958343', 'Chair': '03001627', 'Earphone': '03261776', 'Guitar': '03467517', 'Knife': '03624134', 'Lamp': '03636649', 'Laptop': '03642806', 'Motorbike': '03790512', 'Mug': '03797390', 'Pistol': '03948459', 'Rocket': '04099429', 'Skateboard': '04225987', 'Table': '04379243'}\n",
+      "None\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "\n",
+    "root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet'\n",
+    "catfile = os.path.join(root, 'raw',\n",
+    "                       'synsetoffset2category.txt')\n",
+    "\n",
+    "cat = {}\n",
+    "with open(catfile, 'r') as f:\n",
+    "    for line in f:\n",
+    "        ls = line.strip().split()\n",
+    "        cat[ls[0]] = ls[1]\n",
+    "\n",
+    "print(cat)\n",
+    "\n",
+    "def map_class_id_to_numbers(cat):\n",
+    "    for key, value in cat.items():\n",
+    "        if value == 'Airplane':\n",
+    "            return key\n",
+    "\n",
+    "print(map_class_id_to_numbers(cat))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[-0.0148, -1.7624, -2.2409,  0.8483,  0.4169],\n",
+      "        [-0.3802, -0.0168, -1.9958,  0.9351, -1.0900],\n",
+      "        [ 0.3501, -1.3573, -1.8246, -1.2850,  0.2785]], requires_grad=True)\n",
+      "tensor([0, 0, 1])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Example of target with class indices\n",
+    "import torch\n",
+    "from torch.nn import functional as F\n",
+    "\n",
+    "\n",
+    "input = torch.randn(3, 5, requires_grad=True)\n",
+    "target = torch.randint(5, (3,), dtype=torch.int64)\n",
+    "\n",
+    "print(input)\n",
+    "print(target)\n",
+    "\n",
+    "loss = F.cross_entropy(input, target)\n",
+    "loss.backward()\n",
+    "# Example of target with class probabilities\n",
+    "input = torch.randn(3, 5, requires_grad=True)\n",
+    "target = torch.randn(3, 5).softmax(dim=1)\n",
+    "\n",
+    "print(input)\n",
+    "print(target)\n",
+    "\n",
+    "loss = F.cross_entropy(input, target)\n",
+    "loss.backward()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/dgcnn_main.py b/dgcnn/dgcnn_main.py
index 7173890..1b11c34 100644
--- a/dgcnn/dgcnn_main.py
+++ b/dgcnn/dgcnn_main.py
@@ -12,10 +12,11 @@ def main():
     dgcnn = DGCNN(num_classes)
     dgcnn.eval()
     # simple test
-    input_tensor = torch.randn(1, 128, 3)
+    input_tensor = torch.randn(2, 128, 3)
     print(input_tensor.shape)
     out = dgcnn(input_tensor)
     print(out.shape)
+    print(out)
 
 if __name__ == '__main__':
     main()
diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py
new file mode 100644
index 0000000..08aeea0
--- /dev/null
+++ b/dgcnn/dgcnn_train.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from shapenet_data_dgcnn import ShapenetDataDgcnn
+
+import wandb
+
+# Start the wandb run at import time so train() can log to it.
+wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio")
+
+
+from model import DGCNN
+
+
+def train():
+    """Train DGCNN as a 16-class ShapeNet classifier, logging to wandb.
+
+    Needs a CUDA device (model and batches are moved with .cuda()).
+    """
+    # Each hyperparameter is defined once so the config sent to wandb
+    # cannot drift from the values the loader and optimizer really use.
+    num_classes = 16
+    batch_size = 8
+    learning_rate = 0.0001
+    log_every = 100  # batches between loss reports
+
+    dgcnn = DGCNN(num_classes).cuda()
+    dgcnn.train()
+
+    shapenet_data = ShapenetDataDgcnn(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+        npoints=32,
+        return_cls_label=True,
+        small_data=True,
+        small_data_size=1000,
+        split='train',
+    )
+    # NOTE(review): shuffle=False replays the same sample order each epoch;
+    # confirm this is intended for training.
+    dataloader = torch.utils.data.DataLoader(
+        shapenet_data,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        drop_last=False,
+    )
+    optimizer = torch.optim.Adam(dgcnn.parameters(), lr=learning_rate)
+
+    wandb.config.update({
+        "batch_size": batch_size,
+        "learning_rate": learning_rate,
+        "optimizer": "Adam",
+        "loss_function": "cross_entropy",
+    })
+    wandb.watch(dgcnn)
+
+    for epoch in range(500):
+        print(f"Epoch: {epoch}")
+        wandb.log({"epoch": epoch})
+        for i, (points, _, class_name) in enumerate(dataloader):
+            points = points.cuda()
+            class_name = class_name.cuda()
+            optimizer.zero_grad()
+            pred = dgcnn(points)
+            # cross_entropy takes integer class indices; no one-hot needed.
+            loss = F.cross_entropy(
+                pred, class_name, reduction='mean', ignore_index=255)
+            loss.backward()
+            optimizer.step()
+
+            # Report the loss every `log_every` batches.
+            if i % log_every == 0:
+                print(loss.item())
+                wandb.log({"loss": loss.item()})
+
+    
+# Run training only when this file is executed as a script.
+if __name__ == '__main__':
+    train()
\ No newline at end of file
diff --git a/dgcnn/model.py b/dgcnn/model.py
index 366dd16..9d3c3b4 100644
--- a/dgcnn/model.py
+++ b/dgcnn/model.py
@@ -3,7 +3,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
 
-#TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
+# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
+# TODO: There are more conv layers there
+
 
 class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
@@ -13,7 +15,7 @@ class EdgeConv(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU()
+            nn.ReLU() # TODO: replace with leaky relu
         )
 
     def forward(self, x, k=20):
diff --git a/dgcnn/shapenet_data.py b/dgcnn/shapenet_data_dgcnn.py
similarity index 70%
rename from dgcnn/shapenet_data.py
rename to dgcnn/shapenet_data_dgcnn.py
index 7cd999b..275ca4e 100644
--- a/dgcnn/shapenet_data.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -3,9 +3,10 @@ import json
 import os
 
 import numpy as np
+import torch
 
 
-class ShapenetData(object):
+class ShapenetDataDgcnn(object):
     """
     The is the data loader for the ShapeNet dataset. Only for data segmentation, not for classification.
     """
@@ -15,8 +16,9 @@ class ShapenetData(object):
                  split='train',
                  small_data=False,
                  small_data_size=10,
+                 return_cls_label=False,
                  just_one_class=False,
-                 norm=True,
+                 norm=False,
                  data_augmentation=False
                  ) -> None:
         
@@ -25,6 +27,7 @@ class ShapenetData(object):
         self.split = split
         self.small_data = small_data
         self.small_data_size = small_data_size
+        self.return_cls_label = return_cls_label
         self.just_one_class = just_one_class
         self.norm = norm
         self.data_augmentation = data_augmentation
@@ -92,9 +95,45 @@ class ShapenetData(object):
     def get_seg_classes(self, cat):
         return self.seg_classes[cat]
     
+    def class_mapper(self, class_name):
+        """
+        Translate a ShapeNet synset offset into its integer class index.
+
+        Args:
+            class_name: synset offset string such as '02691156'.
+
+        Returns:
+            The class index (0-15) assigned to that synset offset.
+
+        Raises:
+            ValueError: if class_name is not one of the 16 known offsets.
+        """
+        index_by_synset = {
+            '02691156': 0,
+            '02773838': 1,
+            '02954340': 2,
+            '02958343': 3,
+            '03001627': 4,
+            '03261776': 5,
+            '03467517': 6,
+            '03624134': 7,
+            '03636649': 8,
+            '03642806': 9,
+            '03790512': 10,
+            '03797390': 11,
+            '03948459': 12,
+            '04099429': 13,
+            '04225987': 14,
+            '04379243': 15,
+        }
+        try:
+            return index_by_synset[class_name]
+        except (KeyError, TypeError):  # TypeError: unhashable input
+            raise ValueError('class name not found') from None
+    
     def get_class_names(self):
         return list(self.cat.values())
-    
+
     def get_all_names_of_classes(self):
         return list(self.cat.keys())
     
@@ -115,6 +154,9 @@ class ShapenetData(object):
         elif self.split == 'val':
             point_set = np.loadtxt(self.val_data_file[index]).astype(np.float32)
 
+        # get labels
+        labels = point_set[:, 6]
+
         # get just the points
         point_set = point_set[:, 0:3]
 
@@ -122,14 +164,32 @@ class ShapenetData(object):
         if self.norm:
             point_set = self.normalize(point_set)
 
-        # choice = np.random.choice(len(point_set), self.npoints, replace=True)
+        choice = np.random.choice(len(point_set), self.npoints, replace=True)
         # chose the first npoints
         choice = np.arange(self.npoints)
 
         point_set = point_set[choice, :]
         point_set = point_set.astype(np.float32)
 
-        return point_set
+        # get the labels
+        labels = labels[choice]
+        labels = labels.astype(np.int64)
+
+        # get the class name
+        class_name = self.train_file_list[index].split('/')[-2]
+        # apply the mapper
+        class_name = self.class_mapper(class_name)
+
+        # convert the class name to a number
+        class_name = np.array(class_name, dtype=np.int64)
+
+        # map to tensor
+        class_name = torch.from_numpy(class_name)
+
+        if self.return_cls_label:
+            return point_set, labels, class_name
+        else:
+            return point_set, labels
 
     def __len__(self):
         if self.split == 'train':
@@ -142,7 +202,10 @@ class ShapenetData(object):
 
 if __name__ == "__main__":
 
-  shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
+  shapenet_data = ShapenetDataDgcnn(
+      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+      split='train'
+      )
 #   print(shapenet_data.train_file_list)
   print(shapenet_data.get_seg_classes('Car'))
   print(shapenet_data.get_class_names())
-- 
GitLab