diff --git a/.gitignore b/.gitignore
index 29dc5f63fdf56545bca96f32d2f97aa75d6ab191..0f2e01bf277d36d2bc642b01704ebc3491e67b53 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,5 +130,7 @@ data/
 shapenet_part_seg_hdf5_data/
 ShapeNet/
 maciek_data
+nibio_data
+nibio_data_no_commas
 
 ```
diff --git a/PyG_implementation/my_data_loader.py b/PyG_implementation/my_data_loader.py
index cb1c7846ce5e441fe35ff8dd8d951e49eebe11ee..37ce63dac57d8bf591e8cabaaf19a33234eb197a 100644
--- a/PyG_implementation/my_data_loader.py
+++ b/PyG_implementation/my_data_loader.py
@@ -19,12 +19,14 @@ class MyData(InMemoryDataset):
 
     def __init__(self,
                  root: str,
+                 label_location: int = -1,
                  include_normals: bool = True,
                  split: str = 'trainval',
                  split_ratio: tuple = (0.7, 0.15, 0.15),
                  transform: Optional[Callable] = None,
                  pre_transform: Optional[Callable] = None,
                  pre_filter: Optional[Callable] = None):
+        self.label_location = label_location
         self.split_ratio = split_ratio
         super().__init__(root, transform, pre_transform, pre_filter)
 
@@ -87,7 +89,7 @@ class MyData(InMemoryDataset):
             data = read_txt_array(osp.join(self.raw_dir, name))
             pos = data[:, :3]
             x = data[:, 3:6]
-            y = data[:, -1].type(torch.long)
+            y = data[:, self.label_location].type(torch.long)
             category = torch.tensor(0, dtype=torch.long)  # there is only one category !!
             data = Data(pos=pos, x=x, y=y, category=category)
             if self.pre_filter is not None and not self.pre_filter(data):
diff --git a/PyG_implementation/pyg_implementaion_main_my_data_loader.py b/PyG_implementation/pyg_implementaion_main_my_data_loader.py
index 63aed1992df033ba93bc4d5e6952551919c12322..564fdccd513684c84d4057699f29f2e91a1cded1 100644
--- a/PyG_implementation/pyg_implementaion_main_my_data_loader.py
+++ b/PyG_implementation/pyg_implementaion_main_my_data_loader.py
@@ -11,11 +11,12 @@ from torch_scatter import scatter
 from torchmetrics.functional import jaccard_index
 
 import torch_geometric.transforms as T
-from torch_geometric.datasets import ShapeNet
 from torch_geometric.loader import DataLoader
 from torch_geometric.nn import MLP, DynamicEdgeConv
 
 
+from my_data_loader import MyData
+
 wandb_project = "pyg-point-cloud" #@param {"type": "string"} , maciej-wielgosz-nibio
 wandb_run_name = "train-dgcnn" #@param {"type": "string"}
 
@@ -35,13 +36,12 @@ random.seed(config.seed)
 torch.manual_seed(config.seed)
 device = torch.device(config.device)
 
-config.category = 'Car' #@param ["Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"}
 config.random_jitter_translation = 1e-2
 config.random_rotation_interval_x = 15
 config.random_rotation_interval_y = 15
 config.random_rotation_interval_z = 15
 config.validation_split = 0.2
-config.batch_size = 4
+config.batch_size = 1
 config.num_workers = 6
 
 config.num_nearest_neighbours = 30
@@ -62,23 +62,10 @@ transform = T.Compose([
 ])
 pre_transform = T.NormalizeScale()
 
-
-# dataset_path = os.path.join('ShapeNet', config.category)
-
-
-# train_val_dataset = ShapeNet(
-#     dataset_path, config.category, split='trainval',
-#     transform=transform, pre_transform=pre_transform
-# )
-
-from my_data_loader import MyData
-
 dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/maciek_data/plane_maciek"
+# dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/nibio_data_no_commas"
 
-train_val_dataset = MyData(
-    dataset_path, config.category, split='trainval',
-    transform=transform, pre_transform=pre_transform
-)
+train_val_dataset = MyData(dataset_path, split='trainval', transform=transform, pre_transform=pre_transform)
 
 
 segmentation_class_frequency = {}
@@ -111,7 +98,6 @@ visualization_loader = DataLoader(
 )
 
 
-
 class DGCNN(torch.nn.Module):
     def __init__(self, out_channels, k=30, aggr='max'):
         super().__init__()
@@ -197,9 +183,6 @@ def train_step(epoch):
     iou = torch.tensor(ious, device=device)
     category = torch.cat(categories, dim=0)
 
-    print("iou shape:", iou.shape)
-    print("category shape:", category.shape)
-
     mean_iou = float(scatter(iou, category, reduce='mean').mean())
 
     
@@ -236,10 +219,8 @@ def val_step(epoch):
         total_nodes += data.num_nodes
 
         sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
-        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
-                                    data.category.tolist()):
-            category = list(ShapeNet.seg_classes.keys())[category]
-            part = ShapeNet.seg_classes[category]
+        for out, y in zip(outs.split(sizes), data.y.split(sizes)):
+            part = MyData.seg_classes
             part = torch.tensor(part, device=device)
 
             y_map[part] = torch.arange(part.size(0), device=device)
@@ -278,12 +259,10 @@ def visualization_step(epoch, table):
         y_map = torch.empty(
             visualization_loader.dataset.num_classes, device=device
         ).long()
-        for out, y, category in zip(
-            outs.split(sizes), data.y.split(sizes), data.category.tolist()
-        ):
-            category = list(ShapeNet.seg_classes.keys())[category]
-            part = ShapeNet.seg_classes[category]
+        for out, y in zip(outs.split(sizes), data.y.split(sizes)):
+            part = MyData.seg_classes
             part = torch.tensor(part, device=device)
+
             y_map[part] = torch.arange(part.size(0), device=device)
             iou = jaccard_index(
                 out[:, part].argmax(dim=-1), y_map[y],
diff --git a/PyG_implementation/remove_commas.py b/PyG_implementation/remove_commas.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cafefcf250cb0cb21f18327386716e289e1890c
--- /dev/null
+++ b/PyG_implementation/remove_commas.py
@@ -0,0 +1,31 @@
+import os
+import argparse
+from tqdm import tqdm
+
+def remove_commas(source_folder, target_folder):
+    # Create target folder if it doesn't exist
+    os.makedirs(target_folder, exist_ok=True)
+
+    # Iterate over all txt files in the source folder
+    for filename in os.listdir(source_folder):
+        if filename.endswith(".txt"):
+            # Open the source file and the target file
+            with open(os.path.join(source_folder, filename), 'r') as source_file, \
+                 open(os.path.join(target_folder, filename), 'w') as target_file:
+                # Read each line from the source file
+                for line in tqdm(source_file):
+                    # Replace commas with nothing
+                    line_without_commas = line.replace(',', ' ')
+                    # Write the new line to the target file
+                    target_file.write(line_without_commas)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--source_folder", help="Folder with files that contain commas")
+    parser.add_argument("--target_folder", help="Folder with files that don't contain commas")
+    args = parser.parse_args()
+
+    remove_commas(args.source_folder, args.target_folder)
+
+    print(f"Files without commas are saved in : {args.target_folder}")
+    
\ No newline at end of file