diff --git a/train_partseg_forest.py b/train_partseg_forest.py
index 9bb4a8561e44b1fc8521bc04cf8ec3fb3e510d31..4265d52430488b905d044388b47a5dcc4b00e062 100644
--- a/train_partseg_forest.py
+++ b/train_partseg_forest.py
@@ -10,12 +10,12 @@ import logging
 import sys
 import importlib
 import shutil
+
+import wandb
 import provider
 import numpy as np
 
-from pathlib import Path
 from tqdm import tqdm
-from dataset import PartNormalDataset
 from nibio_transformer_semantic.dataset import Dataset
 import hydra
 import omegaconf
@@ -44,17 +44,14 @@ def to_categorical(y, num_classes):
 def main(args):
     omegaconf.OmegaConf.set_struct(args, False)
 
+    conf = omegaconf.OmegaConf.to_container(args, resolve=True)
+
+    wandb.init(project="forest-point-transformer", entity="maciej-wielgosz-nibio", config=conf)
+
     '''HYPER PARAMETER'''
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
-
-    # print('GPU available: {}'.format(torch.cuda.is_available()))
     logger = logging.getLogger(__name__)
 
-    # print(args.pretty())
-
-    # use pretty print to print the config
-    
-
     train_dataset = hydra.utils.to_absolute_path('data/forest_txt/train_txt/')
     test_dataset = hydra.utils.to_absolute_path('data/forest_txt/test_txt/')
 
@@ -74,7 +71,6 @@ def main(args):
     '''MODEL LOADING'''
     args.input_dim = (6 if args.normal else 3) + 16
     args.num_class = 4
-    num_category = 1
     num_part = args.num_class
     shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
 
@@ -143,20 +139,8 @@ def main(args):
 
             points, label = points.float().cuda(), label.long().cuda()
             optimizer.zero_grad()
-
-            # print("points shape ..: ", points.shape)
-            # print("label shape ...: ", label.shape)
-            # print(" points.shape[1] : ", points.shape[1])
-            # print("to_categorical(label, num_category): ", to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).shape)
-
-            # print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape)
-            # print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape)
-
             seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
-            # seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1))
-
-
-          
+     
             seg_pred = seg_pred.contiguous().view(-1, num_part)
             target = label.view(-1, 1)[:, 0]
             pred_choice = seg_pred.data.max(1)[1]
@@ -164,11 +148,18 @@ def main(args):
             correct = pred_choice.eq(target.data).cpu().sum()
             mean_correct.append(correct.item() / (args.batch_size * args.num_point))
             loss = criterion(seg_pred, target)
+            # add loss to wandb
+            wandb.log({'train_loss': loss.item()})
             loss.backward()
             optimizer.step()
 
+
         train_instance_acc = np.mean(mean_correct)
         logger.info('Train accuracy is: %.5f' % train_instance_acc)
+        # add train accuracy to wandb
+        wandb.log({'train_accuracy': train_instance_acc})
+        # add epoch to wandb
+        wandb.log({'epoch': epoch})
 
         with torch.no_grad():
             test_metrics = {}
@@ -264,8 +255,15 @@ def main(args):
         logger.info('Best accuracy is: %.5f' % best_acc)
         logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
         logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
+        wandb.log({'test_accuracy': test_metrics['accuracy']})
+        wandb.log({'test_class_avg_iou': test_metrics['class_avg_iou']})
+        wandb.log({'test_inctance_avg_iou': test_metrics['inctance_avg_iou']})
+        wandb.log({'best_accuracy': best_acc})
+        wandb.log({'best_class_avg_iou': best_class_avg_iou})
+        wandb.log({'best_inctance_avg_iou': best_inctance_avg_iou})
         global_epoch += 1
 
 
+
 if __name__ == '__main__':
     main()
\ No newline at end of file