From 405cf8243a46ba271a3a88a85bdd856aff9aeba4 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Tue, 4 Apr 2023 15:17:19 +0200
Subject: [PATCH] config corrections

---
 dgcnn/dgcnn_train.py         |  2 +-
 dgcnn/dgcnn_train_pl.py      |  6 +++---
 dgcnn/get_size_of_dataset.py |  6 +++---
 dgcnn/shapenet_data_dgcnn.py | 12 ++++--------
 4 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py
index 132972b..c504c6f 100644
--- a/dgcnn/dgcnn_train.py
+++ b/dgcnn/dgcnn_train.py
@@ -25,7 +25,7 @@ def train():
       return_cls_label=True,
       small_data=True,
       small_data_size=300,
-      just_one_class=False,
+      just_four_classes=False,
       split='train',
       norm=True
       )
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 02023e7..deeb54e 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=config['data']['small_data'],
       small_data_size=config['data']['small_data_size'],
-      just_one_class=config['data']['just_one_class'],
+      just_four_classes=config['data']['just_four_classes'],
       split='train',
       norm=config['data']['norm'],
       augmnetation=config['data']['augmentation']
@@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_one_class=config['data']['just_one_class'],
+        just_four_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )
@@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        just_one_class=config['data']['just_one_class'],
+        just_four_classes=config['data']['just_four_classes'],
         split='test',
         norm=config['data']['norm']
         )
diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py
index a68f35b..af7cf04 100644
--- a/dgcnn/get_size_of_dataset.py
+++ b/dgcnn/get_size_of_dataset.py
@@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=False,
       small_data_size=1000,
-      just_one_class=True,
+      just_four_classes=True,
       split='train',
       norm=True
       )
@@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_one_class=True,
+        just_four_classes=True,
         split='test',
         norm=True
         )
@@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_one_class=True,
+        just_four_classes=True,
         split='val',
         norm=True
         )
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index 32ed47a..a847e0f 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object):
                  small_data=False,
                  small_data_size=10,
                  return_cls_label=False,
-                 just_one_class=False,
+                 just_four_classes=False,
                  norm=False,
                  augmnetation=False,
                  data_augmentation=False
@@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object):
         self.small_data = small_data
         self.small_data_size = small_data_size
         self.return_cls_label = return_cls_label
-        self.just_one_class = just_one_class
+        self.just_four_classes = just_four_classes
         self.norm = norm
         self.augmnetation = augmnetation
         self.data_augmentation = data_augmentation
@@ -83,8 +83,6 @@ class ShapenetDataDgcnn(object):
         with open(json_file, 'r') as f:
             data = json.load(f)
 
-        print('10 data in the list: ', data[:10])
-
         out_data = []
         for i in range(len(data)):
             out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt')))
@@ -92,7 +90,7 @@ class ShapenetDataDgcnn(object):
         # get one class of data
         # get the the number of the class airplane
 
-        if self.just_one_class:
+        if self.just_four_classes:
             out_data = [x for x in out_data if x.split('/')[-2] in [
                 self.cat['Airplane'],
                 self.cat['Lamp'],
@@ -100,8 +98,6 @@ class ShapenetDataDgcnn(object):
                 self.cat['Table'],
                 ]]
         
-        print('10 data in the out_data list: ', out_data[:10])
-
         return out_data
     
     def get_seg_classes(self, cat):
@@ -217,7 +213,7 @@ class ShapenetDataDgcnn(object):
             class_name = self.test_file_list[index].split('/')[-2]
         elif self.split == 'val':
             class_name = self.val_data_file[index].split('/')[-2]
-            
+
         # apply the mapper
         class_name = self.class_mapper(class_name)
 
-- 
GitLab