diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index deeb54e53eeff7dee65fb0d575792b95f8fd4053..c76f4c4835bfcab18dcda2d89f0b7bc55d58b802 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -217,7 +217,11 @@ else:
         )
 
 # Initialize a model
-model = DGCNNLightning(num_classes=16)
+if config['data']['just_four_classes']:
+    num_classes = 4
+else:
+    num_classes = 16
+model = DGCNNLightning(num_classes=num_classes)
 
 if config['wandb']['use_wandb']:
     if config['wandb']['watch_model']:
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index a847e0faa39911bd16b7977776dd8c4a95f8bff9..abf9ba9abe5a31c65d1a3ee89e81a1be89551052 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -138,6 +138,18 @@ class ShapenetDataDgcnn(object):
             return 15
         else:
             raise ValueError('class name not found')
+        
+    def class_mapper_4_classes(self, class_name):
+        if class_name == self.cat['Airplane']:
+            return 0
+        elif class_name == self.cat['Lamp']:
+            return 1
+        elif class_name == self.cat['Chair']:
+            return 2
+        elif class_name == self.cat['Table']:
+            return 3
+        else:
+            raise ValueError('class name not found')
     
     def get_class_names(self):
         return list(self.cat.values())
@@ -215,7 +227,10 @@ class ShapenetDataDgcnn(object):
             class_name = self.val_data_file[index].split('/')[-2]
 
         # apply the mapper
-        class_name = self.class_mapper(class_name)
+        if self.just_four_classes:
+            class_name = self.class_mapper_4_classes(class_name)
+        else:
+            class_name = self.class_mapper(class_name)
 
         # convert the class name to a number
         class_name = np.array(class_name, dtype=np.int64)