diff --git a/dgcnn/dgcnn_main.py b/dgcnn/dgcnn_main_inference.py
similarity index 100%
rename from dgcnn/dgcnn_main.py
rename to dgcnn/dgcnn_main_inference.py
diff --git a/dgcnn/dgcnn_shape_net_inference.py b/dgcnn/dgcnn_shape_net_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..079a7c6750816dabccc4ede3943b92b3d144c196
--- /dev/null
+++ b/dgcnn/dgcnn_shape_net_inference.py
@@ -0,0 +1,21 @@
+import torch
+
+from my_models.model_shape_net import DgcnShapeNet as DGCNN
+
+
+def main():
+    dgcnn = DGCNN(50)
+    dgcnn.eval()
+    # smoke test: two clouds of 128 points (x, y, z) plus a 16-way one-hot class label
+    input_tensor = torch.randn(2, 128, 3)
+    label_one_hot = torch.zeros((2, 16))
+    print(input_tensor.shape)
+    out = dgcnn(input_tensor, label_one_hot)
+    print(out.shape)
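+    # expected: torch.Size([2, 50, 128]), i.e. per-point logits for the 50 part classes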
+
+if __name__ == '__main__':
+    main()
diff --git a/dgcnn/dgcnn_train_shape_net.py b/dgcnn/dgcnn_train_shape_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d9b9b96e626f302ae832e399314fc0f723311f4
--- /dev/null
+++ b/dgcnn/dgcnn_train_shape_net.py
@@ -0,0 +1,116 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from shapenet_data_dgcnn import ShapenetDataDgcnn
+from torchmetrics import Accuracy
+from torchmetrics import JaccardIndex as IoU
+from torch.optim.lr_scheduler import StepLR
+
+import wandb
+
+use_wandb = True
+
+# create a wandb run
+if use_wandb:
+    wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio")
+
+
+from my_models.model_shape_net import DgcnShapeNet as DGCNN
+
+
+def train():
+    seg_num_all = 50
+    dgcnn = DGCNN(seg_num_all).cuda()
+    dgcnn.train()
+    
+    # get data 
+    shapenet_data = ShapenetDataDgcnn(
+      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+      npoints=1024,
+      return_cls_label=True,
+      small_data=False,
+      small_data_size=300,
+      just_four_classes=False,
+      data_augmentation=True,
+      split='train',
+      norm=True
+      )
+    
+    # create a dataloader
+    dataloader = torch.utils.data.DataLoader(
+        shapenet_data,
+        batch_size=4,
+        shuffle=True,
+        num_workers=8,
+        drop_last=True
+        )
+    
+    # create a optimizer
+    optimizer = torch.optim.Adam(dgcnn.parameters(), lr=0.001, weight_decay=1e-4)
+    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
+    if use_wandb:
+        # log the actual run configuration
+        wandb.config.update({
+            "batch_size": 4,
+            "learning_rate": 0.001,
+            "optimizer": "Adam",
+            "loss_function": "cross_entropy"
+            })
+        
+
+
+    # train
+    iou = IoU(task='multiclass', num_classes=50, average='macro').cuda()
+    acc = Accuracy(task='multiclass', num_classes=50).cuda()
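+    # torchmetrics objects accumulate state across calls; reset them each epoch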
+
+    for epoch in range(500):
+        iou.reset()
+        acc.reset()
+        print(f"Epoch: {epoch}")
+        if use_wandb:
+            wandb.log({"epoch": epoch})
+        for i, data in enumerate(dataloader, 0):
+            # print(f"Batch: {i}")
+         
+            points, labels, class_name = data
+            label_one_hot = np.zeros((class_name.shape[0], 16))
+            for idx in range(class_name.shape[0]):
+                label_one_hot[idx, class_name[idx]] = 1
+            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
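+            # equivalently (assuming class_name holds integer ids in [0, 16)):
+            # label_one_hot = F.one_hot(class_name, num_classes=16).float()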
+
+            points = points.cuda()
+            labels = labels.cuda()
+            label_one_hot = label_one_hot.cuda()
+
+            optimizer.zero_grad()
+            pred = dgcnn(points, label_one_hot)
+
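+            # F.cross_entropy accepts (batch, classes, points) logits with (batch, points) integer targets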
+            loss = F.cross_entropy(
+                pred, labels, reduction='mean', ignore_index=255)
+            loss.backward()
+            optimizer.step()
+
+            # log metrics every 100 batches
+            if i % 100 == 0:
+                # argmax of the logits equals argmax of the softmax
+                pred_argmax = torch.argmax(pred, dim=1)   # (batch_size, num_points)
+                batch_iou = iou(pred_argmax, labels)
+                batch_acc = acc(pred_argmax, labels)
+                print("loss : ", loss.item())
+                print("IoU : ", batch_iou.item())
+                print("Acc : ", batch_acc.item())
+                if use_wandb:
+                    wandb.log({
+                        "loss": loss.item(),
+                        "iou": batch_iou.item(),
+                        "acc": batch_acc.item()
+                        })
+
+
+if __name__ == '__main__':
+    train()
\ No newline at end of file
diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py
index 7a9abb348b850f7021a5c0fa61f3c7d75ce41d81..a066be019e96de73cd2e34a13e1e8952ef5173d2 100644
--- a/dgcnn/model_class.py
+++ b/dgcnn/model_class.py
@@ -68,50 +68,118 @@ class SelfAttention(nn.Module):
         out = out.view(batch_size, -1, num_points)
         return out
 
-class EdgeConvNewAtten(nn.Module):
+class EdgeConvNewAttenFull(nn.Module):
     def __init__(self, in_channels, out_channels):
-        super(EdgeConvNewAtten, self).__init__()
+        super(EdgeConvNewAttenFull, self).__init__()
         self.in_channels = in_channels
         self.conv = nn.Sequential(
             nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.LeakyReLU(negative_slope=0.2),
         )
-        self.self_attention = SelfAttention(2*in_channels*20, num_heads=8, dropout=0.1)
+        # embed dims must match what forward() feeds each attention block:
+        # pre- and diff-attention see in_channels, post-attention sees out_channels
+        self.self_attention_pre = SelfAttention(in_channels, num_heads=1, dropout=0.1)
+        self.self_attention_feature = SelfAttention(in_channels, num_heads=1, dropout=0.1)
+        self.self_attention_post = SelfAttention(out_channels, num_heads=1, dropout=0.1)
 
     def forward(self, x, k=20):
         batch_size = x.size(0)
         num_points = x.size(2)
-        x = x.view(batch_size, -1, num_points)
+
+        # Apply self-attention before computing the KNN graph; transpose so the
+        # channel dimension is the attention embedding
+        x = self.self_attention_pre(x.transpose(2, 1)).transpose(2, 1).contiguous()
+
         idx = self.knn(x, k=k)   # (batch_size, num_points, k)
       
         idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
-
         idx = idx + idx_base
-
         idx = idx.view(-1)
     
         _, num_dims, _ = x.size()
-
-        x = x.transpose(2, 1).contiguous()   
+        x = x.transpose(2, 1).contiguous()
         feature = x.view(batch_size*num_points, -1)[idx, :]
         feature = feature.view(batch_size, num_points, k, num_dims) 
 
-        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
-        
-        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
+        # Add self-attention to the feature differences: attend over each point's
+        # k neighbours, with the channel dimension as the embedding
+        feature_diff = feature - x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+        feature_diff = feature_diff.view(batch_size * num_points, k, num_dims)
+        feature_diff = self.self_attention_feature(feature_diff)
+        feature_diff = feature_diff.view(batch_size, num_points, k, num_dims)
 
+        feature = torch.cat((feature_diff, x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)), dim=3).permute(0, 3, 1, 2).contiguous()
         feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
-        # print("feature", feature.shape)
-        feature = feature.permute(0, 2, 1, 3).contiguous()
-        feature = feature.view(batch_size, num_points, -1) 
 
-        # print("feature", feature.shape)
+        # Apply self-attention after the convolution (embed dim = out_channels)
+        feature = feature.permute(0, 2, 3, 1).contiguous()   # (batch_size, num_points, k, out_channels)
+        feature = feature.view(batch_size, num_points * k, -1)
+        feature = self.self_attention_post(feature)
+        feature = feature.view(batch_size, num_points, k, -1).permute(0, 3, 1, 2).contiguous()
 
-        feature = self.self_attention(feature)  # (batch_size, num_points, out_channels)
-        feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
-        # print("feature", feature.shape)
-    
         return feature     
 
     def knn(self, x, k):
@@ -120,6 +188,8 @@ class EdgeConvNewAtten(nn.Module):
         _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
         return idx
 
+
+
 class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(EdgeConv, self).__init__()
@@ -200,6 +270,9 @@ class Transform_Net(nn.Module):
 
         return transform
 
+from my_models.edge_conv_new_atten import EdgeConvNewAtten
+
+
 class DgcnnClass(nn.Module):
     def __init__(self, num_classes):
         super(DgcnnClass, self).__init__()
@@ -224,8 +297,8 @@ class DgcnnClass(nn.Module):
     def forward(self, x):
         # Apply Transform_Net on input point cloud
 
-        trans_matrix = self.transform_net(x)
-        x = torch.bmm(x, trans_matrix)
+        # trans_matrix = self.transform_net(x)
+        # x = torch.bmm(x, trans_matrix)
 
         batch_size = x.size(0)
         num_points = x.size(1)
diff --git a/dgcnn/my_models/__init__.py b/dgcnn/my_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dgcnn/my_models/edge_conv_new.py b/dgcnn/my_models/edge_conv_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..c498b2fd1266f3476dbef51b7f07ad753dbfdd65
--- /dev/null
+++ b/dgcnn/my_models/edge_conv_new.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+
+class EdgeConvNew(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConvNew, self).__init__()
+        self.in_channels = in_channels
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+
+    def forward(self, x, k=20):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=k)   # (batch_size, num_points, k) neighbour indices
+
+        # offset each batch element's indices so they address its own rows once
+        # the batch is flattened to (batch_size*num_points, num_dims)
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+        idx = idx + idx_base
+        idx = idx.view(-1)
+
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()
+        feature = x.view(batch_size*num_points, -1)[idx, :]   # gather k neighbours per point
+        feature = feature.view(batch_size, num_points, k, num_dims)
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+        
+        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+        feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
+    
+        return feature     
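+
+    # quick shape check (a sketch):
+    # ec = EdgeConvNew(3, 64)
+    # ec(torch.randn(2, 3, 128)).shape   # -> torch.Size([2, 64, 128, 20]) before the max over k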
+
+    def knn(self, x, k):
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
\ No newline at end of file
diff --git a/dgcnn/my_models/edge_conv_new_atten.py b/dgcnn/my_models/edge_conv_new_atten.py
new file mode 100644
index 0000000000000000000000000000000000000000..a82efb836d5af302d678d151cbb03c7f73b7cf5d
--- /dev/null
+++ b/dgcnn/my_models/edge_conv_new_atten.py
@@ -0,0 +1,86 @@
+from torch import nn
+import torch
+
+from my_models.self_attention import SelfAttention
+
+class EdgeConvNewAtten(nn.Module):
+    def __init__(self, in_channels, out_channels, k=20):
+        super(EdgeConvNewAtten, self).__init__()
+        self.in_channels = in_channels
+        self.k = k
+        self.conv = nn.Sequential(
+            nn.Conv2d(3*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+        self.self_attention_feature = SelfAttention(in_channels*self.k, num_heads=4, dropout=0.1)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=self.k)   # (batch_size, num_points, k)
+      
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+
+        idx = idx + idx_base
+
+        idx = idx.view(-1)
+    
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()   
+        feature = x.view(batch_size*num_points, -1)[idx, :]
+        # print("feature", feature.shape)
+        feature = feature.view(batch_size, num_points, self.k, num_dims) 
+
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, self.k, 1)
+
+        subtract = feature - x
+        # squared differences as a second geometric feature
+        subtract_pow = torch.pow(subtract, 2)
+
+        # attention over the neighbour differences
+        subtract = subtract.view(batch_size, num_points, -1)
+        subtract = self.self_attention_feature(subtract)  # (batch_size, num_points, k*in_channels)
+        subtract = subtract.reshape(batch_size, num_points, self.k, -1).contiguous()
+
+        # attention over the squared differences
+        subtract_pow = subtract_pow.view(batch_size, num_points, -1)
+        subtract_pow = self.self_attention_feature(subtract_pow)  # (batch_size, num_points, k*in_channels)
+        subtract_pow = subtract_pow.reshape(batch_size, num_points, self.k, -1).contiguous()
+
+        feature = torch.cat((subtract, subtract_pow, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+        feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
+        # print("feature before reshape to the first attention", feature.shape)
+
+        # feature = feature.permute(0, 2, 1, 3).contiguous()
+        # feature = feature.view(batch_size, num_points, -1) 
+        # feature = self.self_attention(feature)  # (batch_size, num_points, out_channels)
+        # feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
+    
+        return feature     
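+
+    # quick shape check (a sketch):
+    # ec = EdgeConvNewAtten(3, 64, k=20)
+    # ec(torch.randn(2, 3, 128)).shape   # -> torch.Size([2, 64, 128, 20])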
+
+    def knn(self, x, k):
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
\ No newline at end of file
diff --git a/dgcnn/my_models/model_dist_atten.py b/dgcnn/my_models/model_dist_atten.py
new file mode 100644
index 0000000000000000000000000000000000000000..7556cb79e7372c9ff705b043396884bb0c2761bd
--- /dev/null
+++ b/dgcnn/my_models/model_dist_atten.py
@@ -0,0 +1,72 @@
+from torch import nn
+import torch
+
+from my_models.self_attention import SelfAttention
+
+class ModelDistAtten(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ModelDistAtten, self).__init__()
+        self.in_channels = in_channels
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+        self.self_attention_feature = SelfAttention(in_channels*20, num_heads=4, dropout=0.1)
+
+    def forward(self, x, k=20):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=k)   # (batch_size, num_points, k)
+      
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+
+        idx = idx + idx_base
+
+        idx = idx.view(-1)
+    
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()   
+        feature = x.view(batch_size*num_points, -1)[idx, :]
+        # print("feature", feature.shape)
+        feature = feature.view(batch_size, num_points, k, num_dims) 
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
+        subtract = feature-x
+        # square the neighbour differences
+        subtract = torch.pow(subtract, 2)
+
+        # attention over the squared neighbour differences
+        subtract = subtract.view(batch_size, num_points, -1)
+        subtract = self.self_attention_feature(subtract)  # (batch_size, num_points, k*in_channels)
+        subtract = subtract.reshape(batch_size, num_points, k, -1).contiguous()
+        
+        feature = torch.cat((subtract, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+        feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
+        # print("feature before reshape to the first attention", feature.shape)
+
+        # feature = feature.permute(0, 2, 1, 3).contiguous()
+        # feature = feature.view(batch_size, num_points, -1) 
+        # feature = self.self_attention(feature)  # (batch_size, num_points, out_channels)
+        # feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
+    
+        return feature     
+
+    def knn(self, x, k):
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
\ No newline at end of file
diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..35935705cb985f61d8e0899808292b9eefaaadb3
--- /dev/null
+++ b/dgcnn/my_models/model_shape_net.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from my_models.transform_net import Transform_Net
+from my_models.edge_conv_new import EdgeConvNew
+
+
+class DgcnShapeNet(nn.Module):
+    def __init__(self, seg_num_all):
+        super(DgcnShapeNet, self).__init__()
+        self.seg_num_all = seg_num_all
+        self.transform_net = Transform_Net()
+        self.edge_conv1 = EdgeConvNew(3, 64)
+        self.edge_conv2 = EdgeConvNew(64, 128)
+        self.bn5 = nn.BatchNorm1d(256)
+        self.bn6 = nn.BatchNorm1d(256)
+        self.bn7 = nn.BatchNorm1d(64)
+        self.bn8 = nn.BatchNorm1d(256)
+        self.bn9 = nn.BatchNorm1d(256)
+        self.bn10 = nn.BatchNorm1d(128)
+        self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
+                                   self.bn5,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.conv6 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
+                                   self.bn6,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
+                                   self.bn7,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
+                                   self.bn8,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.dp1 = nn.Dropout(p=0.5)
+        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
+                                   self.bn9,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.dp2 = nn.Dropout(p=0.5)
+        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
+                                   self.bn10,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)
+    
+
+
+    def forward(self, x, class_label):
+        # Apply Transform_Net on input point cloud
+
+        trans_matrix = self.transform_net(x)
+        x = torch.bmm(x, trans_matrix)
+
+        batch_size = x.size(0)
+        num_points = x.size(1)
+
+        # (batch_size, num_points, 3) -> (batch_size, 3, num_points);
+        # a plain view() would scramble the coordinates here, so transpose instead
+        x = x.transpose(2, 1).contiguous()
+      
+        x1 = self.edge_conv1(x)
+        x1 = x1.max(dim=-1, keepdim=False)[0] 
+        x2 = self.edge_conv2(x1)
+        x2 = x2.max(dim=-1, keepdim=False)[0]
+        x5 = torch.cat((x1, x2), dim=1)  # (batch_size, 64+128, num_points)
+        x = self.conv5(x5)                     # (batch_size, 256, num_points)
+        x = x.max(dim=-1, keepdim=True)[0]    # (batch_size, 256, 1)
+
+        class_label = class_label.view(batch_size, -1, 1)           # (batch_size, num_categories, 1)
+        class_label = self.conv7(class_label)                       # (batch_size, 16, 1) -> (batch_size, 64, 1)
+
+        x = torch.cat((x, class_label), dim=1)            # (batch_size, 320, 1)
+        x = x.repeat(1, 1, num_points)          # (batch_size, 320, num_points)
+
+        x = torch.cat((x, x1, x2), dim=1)   # (batch_size, 320+64+128 = 512, num_points)
+
+        x = self.conv8(x)                       # (batch_size, 512, num_points) -> (batch_size, 256, num_points)
+        x = self.dp1(x)
+        x = self.conv9(x)                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
+        x = self.dp2(x)
+        x = self.conv10(x)                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
+        x = self.conv11(x)                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)
+
+        return x
diff --git a/dgcnn/my_models/self_attention.py b/dgcnn/my_models/self_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..62626a0126ba0cce1c688123b06b672b5fded006
--- /dev/null
+++ b/dgcnn/my_models/self_attention.py
@@ -0,0 +1,21 @@
+from torch import nn
+from torch.nn import MultiheadAttention
+
+# implement self attention
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels, num_heads, dropout):
+        super(SelfAttention, self).__init__()
+        self.in_channels = in_channels
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)
+
+    def forward(self, x):
+        # x: (batch_size, seq_len, in_channels); attention runs across seq_len
+        x = x.permute(1, 0, 2)  # (seq_len, batch_size, in_channels) for MultiheadAttention
+        out, _ = self.self_attention(x, x, x)
+        out = out.permute(1, 0, 2)  # back to (batch_size, seq_len, in_channels)
+        return out
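+
+# minimal shape check (a sketch; the embed dim passed to the constructor must
+# match x.size(2)):
+# attn = SelfAttention(60, num_heads=4, dropout=0.1)
+# out = attn(torch.randn(2, 1024, 60))   # -> torch.Size([2, 1024, 60])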
\ No newline at end of file
diff --git a/dgcnn/my_models/transform_net.py b/dgcnn/my_models/transform_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e0e531f2379527844e883abf1ec84c581e05c23
--- /dev/null
+++ b/dgcnn/my_models/transform_net.py
@@ -0,0 +1,45 @@
+from torch import nn
+import torch 
+import torch.nn.functional as F
+
+
+class Transform_Net(nn.Module):
+    def __init__(self, k=3):
+        super(Transform_Net, self).__init__()
+        self.k = k
+
+        self.conv1 = nn.Conv1d(k, 64, 1)
+        self.conv2 = nn.Conv1d(64, 128, 1)
+        self.conv3 = nn.Conv1d(128, 256, 1)
+
+        self.fc3 = nn.Linear(256, k*k)
+
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(256)
+        self.dropout = nn.Dropout(p=0.5)
+
+    def forward(self, x):
+        # Input shape: (batch_size, num_points, k); transpose for Conv1d
+        x = x.transpose(2, 1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 256)
+
+        x = self.fc3(x)
+        x = self.dropout(x)
+
+        identity = torch.eye(self.k, dtype=x.dtype, device=x.device)
+        transform = x.view(-1, self.k, self.k) + identity
+
+        transform = transform.transpose(2, 1)
+
+        return transform
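+
+# sketch (an assumption, not wired into training): PointNet-style regulariser
+# nudging the predicted transform towards an orthogonal matrix
+# t = Transform_Net()(points)                              # (batch, 3, 3)
+# I = torch.eye(3, device=t.device).unsqueeze(0)
+# ortho_loss = ((torch.bmm(t, t.transpose(2, 1)) - I) ** 2).sum(dim=(1, 2)).mean()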
\ No newline at end of file
diff --git a/nibio_transformer_semantic/tile_point_clouds_in_folders.py b/nibio_transformer_semantic/tile_point_clouds_in_folders.py
new file mode 100644
index 0000000000000000000000000000000000000000..69898f23988c118854c52866356a3cbf7587d3c5
--- /dev/null
+++ b/nibio_transformer_semantic/tile_point_clouds_in_folders.py
@@ -0,0 +1,18 @@
+import os
+
+import pandas as pd
+
+class TilePointCloudsInFolders(object):
+    def __init__(self, dir) -> None:
+        self.dir = dir
+        self.list_of_all_txt_files = [file for file in os.listdir(dir) if file.endswith('.txt')]
+
+    def process_single_file(self, file):
+        # files are listed by basename, so join with the source directory
+        df = pd.read_csv(os.path.join(self.dir, file), header=None, sep=' ')
+        df = df.iloc[:, [0, 1, 2, -2]]   # keep x, y, z and the label column
+        df.columns = ['x', 'y', 'z', 'label']
+        return df
+
+    def process_all_files(self):
+        return [self.process_single_file(file) for file in self.list_of_all_txt_files]
+
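+# example usage (hypothetical path):
+# tiler = TilePointCloudsInFolders('/home/user/point_cloud_tiles')
+# frames = tiler.process_all_files()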