diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index e6ba0c2f8f2877f656b7215782fa81d44e74b055..b891128f1e41f8dd74c0f3848bf8f25fe98e90c0 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -6,6 +6,8 @@ from shapenet_data_dgcnn import ShapenetDataDgcnn import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger from model import DGCNN +from pytorch_lightning.strategies import DDPStrategy + @@ -54,7 +56,7 @@ dataloader = torch.utils.data.DataLoader( wandb_logger = WandbLogger(project="dgcnn", name="dgcnn", entity="maciej-wielgosz-nibio") -trainer = pl.Trainer(accelerator="auto", devices=[0], max_epochs=3, logger=wandb_logger, gpus=1) +trainer = pl.Trainer(strategy=DDPStrategy(find_unused_parameters=True), accelerator="auto", devices=[0], max_epochs=3, logger=wandb_logger, gpus=1) # Initialize a model model = DGCNNLightning(num_classes=16)