diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2dbf98f4fd6e69b96e89ab892f4f7fd245b848f6 --- /dev/null +++ b/dgcnn/config.yaml @@ -0,0 +1,27 @@ +# create a config file +training: + max_epochs: 2 + lr : 0.0001 + batch_size: 4 + shuffle: True + num_workers: 8 + devices: [0] + +model: + num_classes: 16 + +data: + path: /home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet + npoints: 256 + small_data: true + small_data_size: 1000 + just_one_class: false + norm: true + +wandb: + use_wandb: false + project: dgcnn + name: dgcnn + entity: maciej-wielgosz-nibio + + diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index b891128f1e41f8dd74c0f3848bf8f25fe98e90c0..eafe18c51d09526083be67fe820fa4a02f38db0f 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -1,7 +1,6 @@ import torch -import torch.nn as nn import torch.nn.functional as F -import torch.nn.init as init +import yaml from shapenet_data_dgcnn import ShapenetDataDgcnn import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger @@ -9,6 +8,8 @@ from model import DGCNN from pytorch_lightning.strategies import DDPStrategy +with open('config.yaml', 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) class DGCNNLightning(pl.LightningModule): @@ -27,39 +28,54 @@ class DGCNNLightning(pl.LightningModule): return loss def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=0.0001) + optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr']) return optimizer # get data -shapenet_data = ShapenetDataDgcnn( +shapenet_data_train = ShapenetDataDgcnn( root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', - npoints=256, + npoints=config['data']['npoints'], return_cls_label=True, - small_data=False, - small_data_size=1000, - just_one_class=False, + small_data=config['data']['small_data'], + small_data_size=config['data']['small_data_size'], + just_one_class=config['data']['just_one_class'], split='train', - norm=True + norm=config['data']['norm'] ) # create a dataloader -dataloader = torch.utils.data.DataLoader( - shapenet_data, - batch_size=4, - shuffle=True, - num_workers=8, +dataloader_train = torch.utils.data.DataLoader( + shapenet_data_train, + batch_size=config['training']['batch_size'], + shuffle=config['training']['shuffle'], + num_workers=config['training']['num_workers'], drop_last=True ) - # Initialize a trainer -wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio") - -trainer = pl.Trainer(strategy=DDPStrategy(find_unused_parameters=True), accelerator="auto", devices=[0], max_epochs=3, logger=wandb_logger, gpus=1) +if config['wandb']['use_wandb']: + wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio") + trainer = pl.Trainer( + strategy=DDPStrategy(find_unused_parameters=True), + accelerator="auto", + devices=config['wandb']['devices'], + max_epochs=config['training']['max_epochs'], + logger=wandb_logger + ) +else: + trainer = pl.Trainer( + strategy=DDPStrategy(find_unused_parameters=True), + accelerator="auto", + devices=[0], + max_epochs=3 + ) # Initialize a model model = DGCNNLightning(num_classes=16) -wandb_logger.watch(model) + +if config['wandb']['use_wandb']: + wandb_logger.watch(model) # Train the model on gpu -trainer.fit(model, dataloader) +trainer.fit(model, dataloader_train) +