diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml index 2dbf98f4fd6e69b96e89ab892f4f7fd245b848f6..e36f25c6163b8cdedf985c2137fac046e000eb3e 100644 --- a/dgcnn/config.yaml +++ b/dgcnn/config.yaml @@ -1,6 +1,6 @@ # create a config file training: - max_epochs: 2 + max_epochs: 4 lr : 0.0001 batch_size: 4 shuffle: True @@ -19,9 +19,10 @@ data: norm: true wandb: - use_wandb: false + use_wandb: true + watch_model : false project: dgcnn - name: dgcnn + name: dgcnn-train-val-test entity: maciej-wielgosz-nibio diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index eafe18c51d09526083be67fe820fa4a02f38db0f..82fe4c2c2ef2df08a9436f1b8391d8bad32b4621 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -27,11 +27,25 @@ class DGCNNLightning(pl.LightningModule): self.log('train_loss', loss) return loss + def validation_step(self, batch, batch_idx): + points, _, class_name = batch + pred = self(points) + loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255) + self.log('val_loss', loss) + return loss + + def test_step(self, batch, batch_idx): + points, _, class_name = batch + pred = self(points) + loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255) + self.log('test_loss', loss) + return loss + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr']) return optimizer - # get data +# get train data shapenet_data_train = ShapenetDataDgcnn( root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', npoints=config['data']['npoints'], @@ -43,7 +57,31 @@ shapenet_data_train = ShapenetDataDgcnn( norm=config['data']['norm'] ) - # create a dataloader +# get val data +shapenet_data_val = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=config['data']['npoints'], + return_cls_label=True, + small_data=config['data']['small_data'], + small_data_size=config['data']['small_data_size'], + 
just_one_class=config['data']['just_one_class'], + split='val', + norm=config['data']['norm'] + ) + +# get test data +shapenet_data_test = ShapenetDataDgcnn( + root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', + npoints=config['data']['npoints'], + return_cls_label=True, + small_data=config['data']['small_data'], + small_data_size=config['data']['small_data_size'], + just_one_class=config['data']['just_one_class'], + split='test', + norm=config['data']['norm'] + ) + +# create train dataloader dataloader_train = torch.utils.data.DataLoader( shapenet_data_train, batch_size=config['training']['batch_size'], @@ -52,6 +90,24 @@ dataloader_train = torch.utils.data.DataLoader( drop_last=True ) +# create val dataloader (no shuffling, keep all samples for stable metrics) +dataloader_val = torch.utils.data.DataLoader( + shapenet_data_val, + batch_size=config['training']['batch_size'], + shuffle=False, + num_workers=config['training']['num_workers'], + drop_last=False + ) + +# create test dataloader (no shuffling, keep all samples for stable metrics) +dataloader_test = torch.utils.data.DataLoader( + shapenet_data_test, + batch_size=config['training']['batch_size'], + shuffle=False, + num_workers=config['training']['num_workers'], + drop_last=False + ) + # Initialize a trainer if config['wandb']['use_wandb']: @@ -59,7 +115,7 @@ if config['wandb']['use_wandb']: trainer = pl.Trainer( strategy=DDPStrategy(find_unused_parameters=True), accelerator="auto", - devices=config['wandb']['devices'], + devices=config['training']['devices'], max_epochs=config['training']['max_epochs'], logger=wandb_logger ) @@ -67,15 +123,20 @@ else: trainer = pl.Trainer( strategy=DDPStrategy(find_unused_parameters=True), accelerator="auto", - devices=[0], - max_epochs=3 + devices=config['training']['devices'], + max_epochs=config['training']['max_epochs'] ) # Initialize a model model = DGCNNLightning(num_classes=16) if config['wandb']['use_wandb']: - wandb_logger.watch(model) -# Train the model on gpu -trainer.fit(model, 
dataloader_train) + if config['wandb']['watch_model']: + wandb_logger.watch(model) +# Train the model on gpu and validate every epoch +trainer.fit(model, dataloader_train, dataloader_val) + +# Test the model on gpu +trainer.test(model, dataloader_test) +