diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml
index e36f25c6163b8cdedf985c2137fac046e000eb3e..d616ecea4010f2b0c8837d556f0f0b66ded7aea3 100644
--- a/dgcnn/config.yaml
+++ b/dgcnn/config.yaml
@@ -1,6 +1,6 @@
 # create a config file
 training:
-  max_epochs: 4
+  max_epochs: 2
   lr : 0.0001
   batch_size: 4
   shuffle: True
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 82fe4c2c2ef2df08a9436f1b8391d8bad32b4621..3da7d3c04fb90b785dd40551e9fb93b1e64aae86 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -5,9 +5,26 @@ from shapenet_data_dgcnn import ShapenetDataDgcnn
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from model import DGCNN
+from torchmetrics import Accuracy, Precision, Recall
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
 
 
+class MetricsPrinterCallback(Callback):
+    def on_validation_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  val_accuracy: {metrics.get('val_acc', float('nan')):.4f}")
+        print(f"  val_precision: {metrics.get('val_precision', float('nan')):.4f}")
+        print(f"  val_recall: {metrics.get('val_recall', float('nan')):.4f}")
+
+    def on_test_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  test_accuracy: {metrics.get('test_acc', float('nan')):.4f}")
+        print(f"  test_precision: {metrics.get('test_precision', float('nan')):.4f}")
+        print(f"  test_recall: {metrics.get('test_recall', float('nan')):.4f}")
+
 with open('config.yaml', 'r') as f:
     config = yaml.load(f, Loader=yaml.FullLoader)
 
@@ -16,6 +33,19 @@ class DGCNNLightning(pl.LightningModule):
     def __init__(self, num_classes):
         super().__init__()
         self.dgcnn = DGCNN(num_classes)
+        # train define metrics
+        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.train_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # val define metrics
+        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.val_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.val_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # test define metrics
+        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.test_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.test_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+
         
     def forward(self, x):
         return self.dgcnn(x)
@@ -23,24 +53,70 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # NOTE: pred stays as raw logits — F.cross_entropy applies log_softmax internally,
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('train_loss', loss)
+        # metrics
+        self.log('train_loss', loss, sync_dist=True)
+        self.log('train_acc',  self.train_accuracy(pred, class_name),sync_dist=True)
+        self.log('train_precision',  self.train_class_precision(pred, class_name), sync_dist=True)
+        self.log('train_recall',  self.train_recall(pred, class_name), sync_dist=True)
+    
         return loss
     
+    def training_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('train_accuracy_epoch', self.train_accuracy.compute(), sync_dist=True)
+        self.log('train_precision_epoch', self.train_class_precision.compute(), sync_dist=True)
+        self.log('train_recall_epoch', self.train_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.train_accuracy.reset()
+        self.train_class_precision.reset()
+        self.train_recall.reset()
+    
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # NOTE: pred stays as raw logits — F.cross_entropy applies log_softmax internally,
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('val_loss', loss)
+        # update metrics
+        self.log('val_loss', loss, sync_dist=True)
+        self.log('val_acc',  self.val_accuracy(pred, class_name), sync_dist=True)
+        self.log('val_precision',  self.val_class_precision(pred, class_name), sync_dist=True)
+        self.log('val_recall',  self.val_recall(pred, class_name), sync_dist=True)
         return loss
     
+    def validation_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('val_acc_epoch', self.val_accuracy.compute(), sync_dist=True)
+        self.log('val_precision_epoch', self.val_class_precision.compute(), sync_dist=True)
+        self.log('val_recall_epoch', self.val_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.val_accuracy.reset()
+        self.val_class_precision.reset()
+        self.val_recall.reset()
+    
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # NOTE: pred stays as raw logits — F.cross_entropy applies log_softmax internally,
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        # update metrics
         self.log('test_loss', loss)
+        self.log('test_acc',  self.test_accuracy(pred, class_name))
+        self.log('test_precision',  self.test_class_precision(pred, class_name))
+        self.log('test_recall',  self.test_recall(pred, class_name))
         return loss
     
+    def test_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('test_acc_epoch', self.test_accuracy.compute())
+        self.log('test_precision_epoch', self.test_class_precision.compute())
+        self.log('test_recall_epoch', self.test_recall.compute())
+        # reset metrics
+        self.test_accuracy.reset()
+        self.test_class_precision.reset()
+        self.test_recall.reset()
+    
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
         return optimizer
@@ -120,11 +196,13 @@ if config['wandb']['use_wandb']:
         logger=wandb_logger
         )
 else:
+    metrics_printer_callback = MetricsPrinterCallback()
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True), 
         accelerator="auto", 
         devices=config['training']['devices'], 
-        max_epochs=config['training']['max_epochs']
+        max_epochs=config['training']['max_epochs'],
+        callbacks=[metrics_printer_callback]
         )
 
 # Initialize a model
@@ -140,3 +218,5 @@ trainer.fit(model, dataloader_train, dataloader_val)
 trainer.test(model, dataloader_val)
 
 
+
+