diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py
index 59f502ea51d0107b5216bdc6d05f7498f03bcefb..132972b8a3d7ed92815240cbea3bbd631fea209d 100644
--- a/dgcnn/dgcnn_train.py
+++ b/dgcnn/dgcnn_train.py
@@ -23,8 +23,8 @@ def train():
       root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
       npoints=256,
       return_cls_label=True,
-      small_data=False,
-      small_data_size=1000,
+      small_data=True,
+      small_data_size=300,
       just_one_class=False,
       split='train',
       norm=True
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 6a4bae7b55343cafec5613ed3e37a9e5b741e678..67b04b9dfa781cdc97d87e50528183f47040a169 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -8,12 +8,21 @@ from model import DGCNN
 from torchmetrics import Accuracy, Precision, Recall
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
-
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 class MetricsPrinterCallback(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  train_loss: {metrics['train_loss']:.4f}")
+        print(f"  train_accuracy: {metrics['train_accurcy_epoch']:.4f}")
+        print(f"  train_precision: {metrics['train_precision_epoch']:.4f}")
+        print(f"  train_recall: {metrics['train_recall_epoch']:.4f}")
+
     def on_validation_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         print(f"Epoch {trainer.current_epoch}:")
+        print(f"  val_loss: {metrics['val_loss']:.4f}")
         print(f"  val_accuracy: {metrics['val_acc']:.4f}")
         print(f"  val_precision: {metrics['val_precision']:.4f}")
         print(f"  val_recall: {metrics['val_recall']:.4f}")
@@ -53,8 +62,7 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # metrics
         self.log('train_loss', loss, sync_dist=True)
         self.log('train_acc',  self.train_accuracy(pred, class_name),sync_dist=True)
@@ -76,8 +84,7 @@ class DGCNNLightning(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('val_loss', loss, sync_dist=True)
         self.log('val_acc',  self.val_accuracy(pred, class_name), sync_dist=True)
@@ -98,28 +105,31 @@ class DGCNNLightning(pl.LightningModule):
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('test_loss', loss)
-        self.log('test_acc',  self.test_accuracy(pred, class_name))
-        self.log('test_precision',  self.test_class_precision(pred, class_name))
-        self.log('test_recall',  self.test_recall(pred, class_name))
+        self.log('test_acc',  self.test_accuracy(pred, class_name), sync_dist=True)
+        self.log('test_precision',  self.test_class_precision(pred, class_name), sync_dist=True)
+        self.log('test_recall',  self.test_recall(pred, class_name), sync_dist=True)
         return loss
     
     def test_epoch_end(self, outputs):
         # logs epoch metrics
-        self.log('test_acc_epoch', self.test_accuracy.compute())
-        self.log('test_precision_epoch', self.test_class_precision.compute())
-        self.log('test_recall_epoch', self.test_recall.compute())
+        self.log('test_acc_epoch', self.test_accuracy.compute(), sync_dist=True)
+        self.log('test_precision_epoch', self.test_class_precision.compute(), sync_dist=True)
+        self.log('test_recall_epoch', self.test_recall.compute(), sync_dist=True)
         # reset metrics
         self.test_accuracy.reset()
         self.test_class_precision.reset()
         self.test_recall.reset()
     
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
-        return optimizer
+        # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
+        # return optimizer
+        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
+        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
+
 
 # get train data 
 shapenet_data_train = ShapenetDataDgcnn(
@@ -161,7 +171,7 @@ shapenet_data_test = ShapenetDataDgcnn(
 dataloader_train = torch.utils.data.DataLoader(
         shapenet_data_train,
         batch_size=config['training']['batch_size'],
-        shuffle=config['training']['shuffle'],
+        shuffle=True,
         num_workers=config['training']['num_workers'],
         drop_last=True
         )
@@ -170,7 +180,7 @@ dataloader_train = torch.utils.data.DataLoader(
 dataloader_val = torch.utils.data.DataLoader(
         shapenet_data_val,
         batch_size=config['training']['batch_size'],
-        shuffle=config['training']['shuffle'],
+        shuffle=False,
         num_workers=config['training']['num_workers'],
         drop_last=True
         )
@@ -179,7 +189,7 @@ dataloader_val = torch.utils.data.DataLoader(
 dataloader_test = torch.utils.data.DataLoader(
         shapenet_data_test,
         batch_size=config['training']['batch_size'],
-        shuffle=config['training']['shuffle'],
+        shuffle=False,
         num_workers=config['training']['num_workers'],
         drop_last=True
         )
diff --git a/dgcnn/model.py b/dgcnn/model.py
index f90b1e2fd2e36b42f8852af028030b5736ec9504..b3dae823428b84f5329fc676fb94deb93c8656b0 100644
--- a/dgcnn/model.py
+++ b/dgcnn/model.py
@@ -141,10 +141,7 @@ class DGCNN(nn.Module):
         x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
         # print("x7 shape: ", x7.shape)
         x8 = torch.cat((x6, x7), 1)              # (batch_size, emb_dims*2)
-
-        x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
-        
-
-        x9 = torch.max(x4, dim=1, keepdim=True)[0]
-        x10 = self.fc(x9.squeeze(1))
-        return x10
\ No newline at end of file
+        # x9 = x9.max(dim=1, keepdim=False)[0]
+        x10 = self.linear1(x8)
+        x11 = self.fc(x10)
+        return x11
\ No newline at end of file