diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index deeb54e53eeff7dee65fb0d575792b95f8fd4053..c76f4c4835bfcab18dcda2d89f0b7bc55d58b802 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -217,7 +217,11 @@ else: ) # Initialize a model -model = DGCNNLightning(num_classes=16) +if config['data']['just_four_classes']: + num_classes = 4 +else: + num_classes = 16 +model = DGCNNLightning(num_classes=num_classes) if config['wandb']['use_wandb']: if config['wandb']['watch_model']: diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index a847e0faa39911bd16b7977776dd8c4a95f8bff9..abf9ba9abe5a31c65d1a3ee89e81a1be89551052 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -138,6 +138,18 @@ class ShapenetDataDgcnn(object): return 15 else: raise ValueError('class name not found') + + def class_mapper_4_classes(self, class_name): + if class_name == self.cat['Airplane']: + return 0 + elif class_name == self.cat['Lamp']: + return 1 + elif class_name == self.cat['Chair']: + return 2 + elif class_name == self.cat['Table']: + return 3 + else: + raise ValueError('class name not found') def get_class_names(self): return list(self.cat.values()) @@ -215,7 +227,10 @@ class ShapenetDataDgcnn(object): class_name = self.val_data_file[index].split('/')[-2] # apply the mapper - class_name = self.class_mapper(class_name) + if self.just_four_classes: + class_name = self.class_mapper_4_classes(class_name) + else: + class_name = self.class_mapper(class_name) # convert the class name to a number class_name = np.array(class_name, dtype=np.int64)