From 6c1e540d3661f6acd384006dbecc116687b92bc2 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Wed, 5 Apr 2023 14:53:10 +0200
Subject: [PATCH] update for data loader DGCNN for 4 classes

---
 dgcnn/dgcnn_train_pl.py      |  6 +++++-
 dgcnn/shapenet_data_dgcnn.py | 17 ++++++++++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index deeb54e..c76f4c4 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -217,7 +217,11 @@ else:
         )
 
 # Initialize a model
-model = DGCNNLightning(num_classes=16)
+if config['data']['just_four_classes']:
+    num_classes = 4
+else:
+    num_classes = 16
+model = DGCNNLightning(num_classes=num_classes)
 
 if config['wandb']['use_wandb']:
     if config['wandb']['watch_model']:
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index a847e0f..abf9ba9 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -138,6 +138,18 @@ class ShapenetDataDgcnn(object):
             return 15
         else:
             raise ValueError('class name not found')
+        
+    def class_mapper_4_classes(self, class_name):
+        if class_name == self.cat['Airplane']:
+            return 0
+        elif class_name == self.cat['Lamp']:
+            return 1
+        elif class_name == self.cat['Chair']:
+            return 2
+        elif class_name == self.cat['Table']:
+            return 3
+        else:
+            raise ValueError('class name not found')
     
     def get_class_names(self):
         return list(self.cat.values())
@@ -215,7 +227,10 @@ class ShapenetDataDgcnn(object):
             class_name = self.val_data_file[index].split('/')[-2]
 
         # apply the mapper
-        class_name = self.class_mapper(class_name)
+        if self.just_four_classes:
+            class_name = self.class_mapper_4_classes(class_name)
+        else:
+            class_name = self.class_mapper(class_name)
 
         # convert the class name to a number
         class_name = np.array(class_name, dtype=np.int64)
-- 
GitLab