diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml
index 2dbf98f4fd6e69b96e89ab892f4f7fd245b848f6..e36f25c6163b8cdedf985c2137fac046e000eb3e 100644
--- a/dgcnn/config.yaml
+++ b/dgcnn/config.yaml
@@ -1,6 +1,6 @@
 # create a config file
 training:
-  max_epochs: 2
+  max_epochs: 4
   lr : 0.0001
   batch_size: 4
   shuffle: True
@@ -19,9 +19,10 @@ data:
   norm: true
 
 wandb:
-  use_wandb: false
+  use_wandb: true
+  watch_model: false
   project: dgcnn
-  name: dgcnn
+  name: dgcnn-train-val-test
   entity: maciej-wielgosz-nibio
 
 
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index eafe18c51d09526083be67fe820fa4a02f38db0f..82fe4c2c2ef2df08a9436f1b8391d8bad32b4621 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -27,11 +27,25 @@ class DGCNNLightning(pl.LightningModule):
         self.log('train_loss', loss)
         return loss
     
+    def validation_step(self, batch, batch_idx):
+        points, _, class_name = batch
+        pred = self(points)
+        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        self.log('val_loss', loss)
+        return loss
+    
+    def test_step(self, batch, batch_idx):
+        points, _, class_name = batch
+        pred = self(points)
+        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        self.log('test_loss', loss)
+        return loss
+    
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
         return optimizer
 
- # get data 
+# get train data 
 shapenet_data_train = ShapenetDataDgcnn(
       root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
       npoints=config['data']['npoints'],
@@ -43,7 +57,31 @@ shapenet_data_train = ShapenetDataDgcnn(
       norm=config['data']['norm']
       )
 
-  # create a dataloader
+# get val data
+shapenet_data_val = ShapenetDataDgcnn(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+        npoints=config['data']['npoints'],
+        return_cls_label=True,
+        small_data=config['data']['small_data'],
+        small_data_size=config['data']['small_data_size'],
+        just_one_class=config['data']['just_one_class'],
+        split='val',
+        norm=config['data']['norm']
+        )
+
+# get test data
+shapenet_data_test = ShapenetDataDgcnn(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+        npoints=config['data']['npoints'],
+        return_cls_label=True,
+        small_data=config['data']['small_data'],
+        small_data_size=config['data']['small_data_size'],
+        just_one_class=config['data']['just_one_class'],
+        split='test',
+        norm=config['data']['norm']
+        )
+
+# create train dataloader
 dataloader_train = torch.utils.data.DataLoader(
         shapenet_data_train,
         batch_size=config['training']['batch_size'],
@@ -52,6 +90,24 @@ dataloader_train = torch.utils.data.DataLoader(
         drop_last=True
         )
 
+# create val dataloader
+dataloader_val = torch.utils.data.DataLoader(
+        shapenet_data_val,
+        batch_size=config['training']['batch_size'],
+        shuffle=False,
+        num_workers=config['training']['num_workers'],
+        drop_last=False
+        )
+
+# create test dataloader
+dataloader_test = torch.utils.data.DataLoader(
+        shapenet_data_test,
+        batch_size=config['training']['batch_size'],
+        shuffle=False,
+        num_workers=config['training']['num_workers'],
+        drop_last=False
+        )
+
 # Initialize a trainer
 
 if config['wandb']['use_wandb']:
@@ -59,7 +115,7 @@ if config['wandb']['use_wandb']:
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True), 
         accelerator="auto", 
-        devices=config['wandb']['devices'], 
+        devices=config['training']['devices'], 
         max_epochs=config['training']['max_epochs'], 
         logger=wandb_logger
         )
@@ -67,15 +123,20 @@ else:
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True), 
         accelerator="auto", 
-        devices=[0], 
-        max_epochs=3
+        devices=config['training']['devices'], 
+        max_epochs=config['training']['max_epochs']
         )
 
 # Initialize a model
 model = DGCNNLightning(num_classes=16)
 
 if config['wandb']['use_wandb']:
-    wandb_logger.watch(model)
-# Train the model on gpu
-trainer.fit(model, dataloader_train)
+    if config['wandb']['watch_model']:
+        wandb_logger.watch(model)
+# Train the model on gpu and validate every epoch
+trainer.fit(model, dataloader_train, dataloader_val)
+
+# Test the model on gpu
+trainer.test(model, dataloader_test)
+