diff --git a/dgcnn/dgcnn_shape_net_inference.py b/dgcnn/dgcnn_shape_net_inference.py
index 079a7c6750816dabccc4ede3943b92b3d144c196..36dcba693c6aae1a1bd0e88f22ed15611b101fdb 100644
--- a/dgcnn/dgcnn_shape_net_inference.py
+++ b/dgcnn/dgcnn_shape_net_inference.py
@@ -5,14 +5,12 @@ import torch.nn.init as init
 
 from my_models.model_shape_net import DgcnShapeNet as DGCNN
 
-
-
 def main():
-    dgcnn = DGCNN(50)
+    dgcnn = DGCNN(50, 4)
     dgcnn.eval()
     # simple test
     input_tensor = torch.randn(2, 128, 3)
-    label_one_hot = torch.zeros((2, 16))
+    label_one_hot = torch.zeros((2, 4))
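+    # the one-hot width (4) must match the num_classes argument passed to DGCNN above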
     print(input_tensor.shape)
     out = dgcnn(input_tensor, label_one_hot)
     print(out.shape)
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index e49a9433c0b5ffcb95667ff53e70b217c4bff044..dd1366c46b2403dd38471a2e6de054ee5b580af1 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=config['data']['small_data'],
       small_data_size=config['data']['small_data_size'],
-      num_classes=config['data']['just_four_classes'],
+      num_classes=config['data']['num_classes'],
       split='train',
       norm=config['data']['norm'],
       augmnetation=config['data']['augmentation']
@@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        num_classes=config['data']['just_four_classes'],
+        num_classes=config['data']['num_classes'],
         split='test',
         norm=config['data']['norm']
         )
@@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
-        num_classes=config['data']['just_four_classes'],
+        num_classes=config['data']['num_classes'],
         split='test',
         norm=config['data']['norm']
         )
@@ -217,10 +217,8 @@ else:
         )
 
 # Initialize a model
-if config['data']['just_four_classes']:
-    num_classes = 4
-else:
-    num_classes = 16
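+# num_classes now comes straight from the config; ShapenetDataDgcnn accepts 1-4 (reduced mapping) or 16 (all categories)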
+num_classes = int(config['data']['num_classes'])
+
 model = DGCNNLightning(num_classes=num_classes)
 
 if config['wandb']['use_wandb']:
diff --git a/dgcnn/dgcnn_train_shape_net.py b/dgcnn/dgcnn_train_shape_net.py
index 9d9b9b96e626f302ae832e399314fc0f723311f4..9aeb3884f3ae603004c5fe9d46580924a2fcb013 100644
--- a/dgcnn/dgcnn_train_shape_net.py
+++ b/dgcnn/dgcnn_train_shape_net.py
@@ -31,12 +31,12 @@ def train():
     # get data 
     shapenet_data = ShapenetDataDgcnn(
       root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
-      npoints=1024,
+      npoints=32,
       return_cls_label=True,
-      small_data=False,
-      small_data_size=300,
-      just_four_classes=False,
-      data_augmentation=True,
+      small_data=True,
+      small_data_size=10,
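+      # 16 = all ShapeNet categories; matches the 16-way one-hot built in the training loop below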
+      num_classes=16,
+      data_augmentation=False,
       split='train',
       norm=True
       )
@@ -44,7 +44,7 @@ def train():
     # create a dataloader
     dataloader = torch.utils.data.DataLoader(
         shapenet_data,
-        batch_size=4,
+        batch_size=8,
         shuffle=True,
         num_workers=8,
         drop_last=True
@@ -77,6 +77,11 @@ def train():
             # print(f"Batch: {i}")
          
             points, labels, class_name = data
+            # log data to wandb
+            if use_wandb:
+                wandb.log({"points": wandb.Object3D(points[0, :, :].cpu().numpy())})
+                # wandb.log({"labels": labels.cpu().numpy()})
+                # wandb.log({"class_name": class_name.cpu().numpy()})
             label_one_hot = np.zeros((class_name.shape[0], 16))
             for idx in range(class_name.shape[0]):
                 label_one_hot[idx, class_name[idx]] = 1
diff --git a/dgcnn/dgcnn_train_shape_net_cat_1-4.py b/dgcnn/dgcnn_train_shape_net_cat_1-4.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c25b49d3e93d7aa2c111b4e54b6358317e88699
--- /dev/null
+++ b/dgcnn/dgcnn_train_shape_net_cat_1-4.py
@@ -0,0 +1,122 @@
+import random
+import numpy as np
+import torch
+import os
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from matplotlib import pyplot as plt
+from shapenet_data_dgcnn import ShapenetDataDgcnn
+# import IoU from torchmetrics
+from torchmetrics import Accuracy
+from torchmetrics import JaccardIndex as IoU
+from torch.optim.lr_scheduler import StepLR
+
+import wandb
+
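+# this script trains on the reduced 4-category ShapeNet subset
+# (ShapenetDataDgcnn maps category names through class_mapper_4_classes when num_classes is 1-4)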
+use_wandb = False
+num_classes = 4
+batch_size = 4
+
+# create a wandb run
+if use_wandb:
+    run = wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio")
+
+
+from my_models.model_shape_net import DgcnShapeNet as DGCNN
+
+
+def train():
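+    # 50 = total number of part labels in the ShapeNet part-segmentation benchmark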
+    seg_num_all = 50
+    dgcnn = DGCNN(
+        seg_num_all=seg_num_all, 
+        num_classes=num_classes
+        ).cuda()
+    dgcnn.train()
+    
+    # get data 
+    shapenet_data = ShapenetDataDgcnn(
+      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+      npoints=32,
+      return_cls_label=True,
+      small_data=True,
+      small_data_size=10,
+      num_classes=num_classes,
+      data_augmentation=False,
+      split='train',
+      norm=True
+      )
+    
+    # create a dataloader
+    dataloader = torch.utils.data.DataLoader(
+        shapenet_data,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=8,
+        drop_last=True
+        )
+    
+    # create an optimizer and a step-decay LR scheduler
+    optimizer = torch.optim.Adam(dgcnn.parameters(), lr=0.001, weight_decay=1e-4)
+    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
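+    # StepLR halves the learning rate every 20 scheduler steps (stepped once per batch below)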
+    if use_wandb:
+        # log the run configuration to wandb
+        wandb.config.update({
+            "batch_size": batch_size,
+            "learning_rate": 0.01,
+            "optimizer": "Adam",
+            "loss_function": "cross_entropy"
+            })
+        
+    # train
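+    # both metrics cover all 50 part labels, independent of the 4-category restriction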
+    iou = IoU(num_classes=50, task='multiclass', average='macro').cuda()
+    acc = Accuracy(num_classes=50, compute_on_step=False, dist_sync_on_step=False, task='multiclass').cuda()
+
+    for epoch in range(500):
+        iou.reset()
+        print(f"Epoch: {epoch}")
+        if use_wandb:
+            wandb.log({"epoch": epoch})
+        for i, data in enumerate(dataloader, 0):
+            # print(f"Batch: {i}")
+         
+            points, labels, class_name = data
+            print('class_name', class_name)
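+            # build a one-hot category vector; the model concatenates it with the global point feature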
+            label_one_hot = np.zeros((class_name.shape[0], num_classes))
+            for idx in range(class_name.shape[0]):
+                label_one_hot[idx, class_name[idx]] = 1
+            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
+
+            points = points.cuda()
+            labels = labels.cuda()
+            label_one_hot = label_one_hot.cuda()
+
+            optimizer.zero_grad()
+            pred = dgcnn(points, label_one_hot)
+
+            loss = F.cross_entropy(
+                pred, labels, reduction='mean', ignore_index=255)
+            loss.backward()
+            optimizer.step()
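+            # keep decaying the learning rate until it bottoms out near 1e-5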
+            if optimizer.param_groups[0]['lr'] > 1e-5:
+                scheduler.step()
+
+            # print loss and metrics every 100 batches
+            if i % 100 == 0:
+                pred_argmax = torch.argmax(pred, dim=1)
+                print("loss : ", loss.item())
+                print("IoU : ",  iou(pred_argmax, labels))
+                print("Acc : ",  acc(pred_argmax, labels))
+                if use_wandb:
+                    wandb.log({"loss": loss.item()})
+                    wandb.log({"iou": iou(pred, labels)})
+                    wandb.log({"acc": acc(pred, labels)})
+
+
+if __name__ == '__main__':
+    train()
\ No newline at end of file
diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py
index fc275ac8734d6df2cc9547bc14d46b4fb0b5b8eb..244f4eaf70db12d8a5ee9a9d4b5e55785e62ff37 100644
--- a/dgcnn/my_models/model_shape_net.py
+++ b/dgcnn/my_models/model_shape_net.py
@@ -43,7 +43,7 @@ class DgcnShapeNet(nn.Module):
     
 
 
-    def forward(self, x, class_label):
+    def forward(self, x, class_label_one_hot):
         # Apply Transform_Net on input point cloud
 
         trans_matrix = self.transform_net(x)
@@ -63,10 +63,10 @@ class DgcnShapeNet(nn.Module):
         x = self.conv5(x5)                     # (batch_size, 1024, num_points)   
         x = x.max(dim=-1, keepdim=True)[0]    # (batch_size, 1024, 1)
 
-        class_label = class_label.view(batch_size, -1, 1)           # (batch_size, num_categoties, 1)
-        class_label = self.conv7(class_label)                       # (batch_size, num_categoties, 1) -> (batch_size, 64, 1)
+        class_label_one_hot = class_label_one_hot.view(batch_size, -1, 1)           # (batch_size, num_categories, 1)
+        class_label_one_hot = self.conv7(class_label_one_hot)                       # (batch_size, num_categories, 1) -> (batch_size, 64, 1)
 
-        x = torch.cat((x, class_label), dim=1)            # (batch_size, 1088, 1)
+        x = torch.cat((x, class_label_one_hot), dim=1)            # (batch_size, 1088, 1)
         x = x.repeat(1, 1, num_points)          # (batch_size, 1088, num_points)
 
         x = torch.cat((x, x1, x2), dim=1)   # (batch_size, 1088+64*3, num_points)
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index d518a639e1fb2a97cd4ed6379664b5ff2820188f..71fb5f89c8282419bc70f5c1562c32417897a0a5 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object):
                  small_data=False,
                  small_data_size=10,
                  return_cls_label=False,
-                 num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4
+                 num_classes=1, # 1-4 use the reduced four-class mapping, 16 uses all ShapeNet categories
                  norm=False,
                  augmnetation=False,
                  data_augmentation=False
@@ -241,12 +241,15 @@ class ShapenetDataDgcnn(object):
             class_name = self.val_data_file[index].split('/')[-2]
 
         # apply the mapper
-        # if self.num_classes:
-        #     class_name = self.class_mapper_4_classes(class_name)
-        # else:
-        #     class_name = self.class_mapper(class_name)
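+        # 1-4 -> reduced four-class mapper, 16 -> full 16-category ShapeNet mapper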
+        if self.num_classes in range(1, 5):
+            class_name = self.class_mapper_4_classes(class_name)
+        elif self.num_classes == 16:
+            class_name = self.class_mapper(class_name)
+        else:
+            raise ValueError('num_classes must be 1, 2, 3, 4, or 16')
+        
 
-        class_name = self.class_mapper(class_name)
+        # class_name = self.class_mapper(class_name)
         
         # convert the class name to a number
         class_name = np.array(class_name, dtype=np.int64)
@@ -254,6 +257,7 @@ class ShapenetDataDgcnn(object):
         # map to tensor
         # class_name = torch.from_numpy(class_name)
 
+
         if self.return_cls_label:
             return point_set, labels, class_name
         else: