diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index c76f4c4835bfcab18dcda2d89f0b7bc55d58b802..e49a9433c0b5ffcb95667ff53e70b217c4bff044 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_four_classes=config['data']['just_four_classes'], + num_classes=config['data']['just_four_classes'], split='train', norm=config['data']['norm'], augmnetation=config['data']['augmentation'] @@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_four_classes=config['data']['just_four_classes'], + num_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] ) @@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_four_classes=config['data']['just_four_classes'], + num_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] )