From 405cf8243a46ba271a3a88a85bdd856aff9aeba4 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Tue, 4 Apr 2023 15:17:19 +0200 Subject: [PATCH] config corrections --- dgcnn/dgcnn_train.py | 2 +- dgcnn/dgcnn_train_pl.py | 6 +++--- dgcnn/get_size_of_dataset.py | 6 +++--- dgcnn/shapenet_data_dgcnn.py | 12 ++++-------- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py index 132972b..c504c6f 100644 --- a/dgcnn/dgcnn_train.py +++ b/dgcnn/dgcnn_train.py @@ -25,7 +25,7 @@ def train(): return_cls_label=True, small_data=True, small_data_size=300, - just_one_class=False, + just_four_classes=False, split='train', norm=True ) diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index 02023e7..deeb54e 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='train', norm=config['data']['norm'], augmnetation=config['data']['augmentation'] @@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] ) @@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=config['data']['small_data'], small_data_size=config['data']['small_data_size'], - just_one_class=config['data']['just_one_class'], + just_four_classes=config['data']['just_four_classes'], split='test', norm=config['data']['norm'] ) diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py index a68f35b..af7cf04 100644 --- a/dgcnn/get_size_of_dataset.py +++ b/dgcnn/get_size_of_dataset.py @@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='train', norm=True ) @@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='test', norm=True ) @@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn( return_cls_label=True, small_data=False, small_data_size=1000, - just_one_class=True, + just_four_classes=True, split='val', norm=True ) diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 32ed47a..a847e0f 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object): small_data=False, small_data_size=10, return_cls_label=False, - just_one_class=False, + just_four_classes=False, norm=False, augmnetation=False, data_augmentation=False @@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object): self.small_data = small_data self.small_data_size = small_data_size self.return_cls_label = return_cls_label - self.just_one_class = just_one_class + self.just_four_classes = just_four_classes self.norm = norm self.augmnetation = augmnetation self.data_augmentation = data_augmentation @@ -83,8 +83,6 @@ class ShapenetDataDgcnn(object): with open(json_file, 'r') as f: data = json.load(f) - print('10 data in the list: ', data[:10]) - out_data = [] for i in range(len(data)): out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))) @@ -92,7 +90,7 @@ class ShapenetDataDgcnn(object): # get one class of data # get the the number of the class airplane - if self.just_one_class: + if self.just_four_classes: out_data = [x for x in out_data if x.split('/')[-2] in [ self.cat['Airplane'], self.cat['Lamp'], @@ -100,8 +98,6 @@ class ShapenetDataDgcnn(object): self.cat['Table'], ]] - print('10 data in the out_data list: ', out_data[:10]) - return out_data def get_seg_classes(self, cat): @@ -217,7 +213,7 @@ class ShapenetDataDgcnn(object): class_name = self.test_file_list[index].split('/')[-2] elif self.split == 'val': class_name = self.val_data_file[index].split('/')[-2] - + # apply the mapper class_name = self.class_mapper(class_name) -- GitLab