diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py index 132972b8a3d7ed92815240cbea3bbd631fea209d..c504c6fa3405777a0cb6ba3e28af96931ee3057b 100644 --- a/dgcnn/dgcnn_train.py +++ b/dgcnn/dgcnn_train.py @@ -25,7 +25,7 @@ def train(): return_cls_label=True, small_data=True, small_data_size=300, - just_one_class=False, + just_four_classes=False, split='train', norm=True ) diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index 02023e7240a490f1f84f59442411eb2b46e325a5..deeb54e53eeff7dee65fb0d575792b95f8fd4053 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='train', norm=config['data']['norm'], augmnetation=config['data']['augmentation'] @@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] ) @@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] ) diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py index a68f35be1288369637ab6783f9a701875eb15433..af7cf04e6a8097d410e1cadb63fa824cf029b436 100644 --- a/dgcnn/get_size_of_dataset.py +++ b/dgcnn/get_size_of_dataset.py @@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='train', norm=True ) @@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='test', norm=True ) @@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='val', norm=True ) diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 32ed47a1cdf706c81b66b7ac41d786511fb96bb2..a847e0faa39911bd16b7977776dd8c4a95f8bff9 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object): small_data=False, small_data_size=10, return_cls_label=False, - just_one_class=False, + just_four_classes=False, norm=False, augmnetation=False, data_augmentation=False @@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object): self.small_data = small_data self.small_data_size = small_data_size self.return_cls_label = return_cls_label - self.just_one_class = just_one_class + self.just_four_classes = just_four_classes self.norm = norm self.augmnetation = augmnetation self.data_augmentation = data_augmentation @@ -83,8 +83,6 @@ class ShapenetDataDgcnn(object): with open(json_file, 'r') as f: data = json.load(f) - print('10 data in the list: ', data[:10]) - out_data = [] for i in range(len(data)): out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))) @@ -92,7 +90,7 @@ class ShapenetDataDgcnn(object): # get one class of data # get the the number of the class airplane - if self.just_one_class: + if self.just_four_classes: out_data = [x for x in out_data if x.split('/')[-2] in [ self.cat['Airplane'], self.cat['Lamp'], @@ -100,8 +98,6 @@ class ShapenetDataDgcnn(object): self.cat['Table'], ]] - print('10 data in the out_data list: ', out_data[:10]) - return out_data def get_seg_classes(self, cat): @@ -217,7 +213,7 @@ class ShapenetDataDgcnn(object): class_name = self.test_file_list[index].split('/')[-2] elif self.split == 'val': class_name = self.val_data_file[index].split('/')[-2] - + # apply the mapper class_name = self.class_mapper(class_name)