From 6c1e540d3661f6acd384006dbecc116687b92bc2 Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Wed, 5 Apr 2023 14:53:10 +0200 Subject: [PATCH] update for data loader DGCNN for 4 classes --- dgcnn/dgcnn_train_pl.py | 6 +++++- dgcnn/shapenet_data_dgcnn.py | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index deeb54e..c76f4c4 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -217,7 +217,11 @@ else: ) # Initialize a model -model = DGCNNLightning(num_classes=16) +if config['data']['just_four_classes']: + num_classes = 4 +else: + num_classes = 16 +model = DGCNNLightning(num_classes=num_classes) if config['wandb']['use_wandb']: if config['wandb']['watch_model']: diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index a847e0f..abf9ba9 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -138,6 +138,18 @@ class ShapenetDataDgcnn(object): return 15 else: raise ValueError('class name not found') + + def class_mapper_4_classes(self, class_name): + if class_name == self.cat['Airplane']: + return 0 + elif class_name == self.cat['Lamp']: + return 1 + elif class_name == self.cat['Chair']: + return 2 + elif class_name == self.cat['Table']: + return 3 + else: + raise ValueError('class name not found') def get_class_names(self): return list(self.cat.values()) @@ -215,7 +227,10 @@ class ShapenetDataDgcnn(object): class_name = self.val_data_file[index].split('/')[-2] # apply the mapper - class_name = self.class_mapper(class_name) + if self.just_four_classes: + class_name = self.class_mapper_4_classes(class_name) + else: + class_name = self.class_mapper(class_name) # convert the class name to a number class_name = np.array(class_name, dtype=np.int64) -- GitLab