diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py
index 132972b8a3d7ed92815240cbea3bbd631fea209d..c504c6fa3405777a0cb6ba3e28af96931ee3057b 100644
--- a/dgcnn/dgcnn_train.py
+++ b/dgcnn/dgcnn_train.py
@@ -25,7 +25,7 @@ def train():
       return_cls_label=True,
       small_data=True,
       small_data_size=300,
-      just_one_class=False,
+      just_four_classes=False,
       split='train',
       norm=True
       )
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 02023e7240a490f1f84f59442411eb2b46e325a5..deeb54e53eeff7dee65fb0d575792b95f8fd4053 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=config['data']['small_data'],
       small_data_size=config['data']['small_data_size'],
-      just_one_class=config['data']['just_one_class'],
+      just_four_classes=config['data']['just_four_classes'],
       split='train',
       norm=config['data']['norm'],
       augmnetation=config['data']['augmentation']
@@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_one_class=config['data']['just_one_class'],
+        just_four_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )
@@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_one_class=config['data']['just_one_class'],
+        just_four_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )
diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py
index a68f35be1288369637ab6783f9a701875eb15433..af7cf04e6a8097d410e1cadb63fa824cf029b436 100644
--- a/dgcnn/get_size_of_dataset.py
+++ b/dgcnn/get_size_of_dataset.py
@@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=False,
       small_data_size=1000,
-      just_one_class=True,
+      just_four_classes=True,
       split='train',
       norm=True
       )
@@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_one_class=True,
+        just_four_classes=True,
         split='test',
         norm=True
         )
@@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_one_class=True,
+        just_four_classes=True,
         split='val',
         norm=True
         )
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index 32ed47a1cdf706c81b66b7ac41d786511fb96bb2..a847e0faa39911bd16b7977776dd8c4a95f8bff9 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object):
                  small_data=False,
                  small_data_size=10,
                  return_cls_label=False,
-                 just_one_class=False,
+                 just_four_classes=False,
                  norm=False,
                  augmnetation=False,
                  data_augmentation=False
@@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object):
         self.small_data = small_data
         self.small_data_size = small_data_size
         self.return_cls_label = return_cls_label
-        self.just_one_class = just_one_class
+        self.just_four_classes = just_four_classes
         self.norm = norm
         self.augmnetation = augmnetation
         self.data_augmentation = data_augmentation
@@ -83,8 +83,6 @@ class ShapenetDataDgcnn(object):
         with open(json_file, 'r') as f:
             data = json.load(f)
 
-        print('10 data in the list: ', data[:10])
-
         out_data = []
         for i in range(len(data)):
             out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt')))
@@ -92,7 +90,7 @@ class ShapenetDataDgcnn(object):
         # get one class of data
         # get the the number of the class airplane
 
-        if self.just_one_class:
+        if self.just_four_classes:
             out_data = [x for x in out_data if x.split('/')[-2] in [
                 self.cat['Airplane'],
                 self.cat['Lamp'],
@@ -100,8 +98,6 @@ class ShapenetDataDgcnn(object):
                 self.cat['Table'],
                 ]]
         
-        print('10 data in the out_data list: ', out_data[:10])
-
         return out_data
     
     def get_seg_classes(self, cat):
@@ -217,7 +213,7 @@ class ShapenetDataDgcnn(object):
             class_name = self.test_file_list[index].split('/')[-2]
         elif self.split == 'val':
             class_name = self.val_data_file[index].split('/')[-2]
-            
+
         # apply the mapper
         class_name = self.class_mapper(class_name)