diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index c76f4c4835bfcab18dcda2d89f0b7bc55d58b802..e49a9433c0b5ffcb95667ff53e70b217c4bff044 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=config['data']['small_data'],
       small_data_size=config['data']['small_data_size'],
-      just_four_classes=config['data']['just_four_classes'],
+      num_classes=config['data']['just_four_classes'],
       split='train',
       norm=config['data']['norm'],
       augmnetation=config['data']['augmentation']
@@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_four_classes=config['data']['just_four_classes'],
+        num_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )
@@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_four_classes=config['data']['just_four_classes'],
+        num_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )