diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 6db78aae24565ad018b4ef18a1dea3424f7db72d..02023e7240a490f1f84f59442411eb2b46e325a5 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -141,7 +141,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       just_one_class=config['data']['just_one_class'],
       split='train',
       norm=config['data']['norm'],
-      augmnetation=True
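+      # augmentation is now controlled from the config instead of being hard-coded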
+      augmnetation=config['data']['augmentation']
       )
 
 # get val data
@@ -152,7 +152,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         small_data=config['data']['small_data'],
         small_data_size=config['data']['small_data_size'],
         just_one_class=config['data']['just_one_class'],
-        split='train',
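+        # the validation loader now uses the held-out 'test' split instead of re-using 'train'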
+        split='test',
         norm=config['data']['norm']
         )
 
diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py
index 10ccfce370db7fcf40d210f38fe8011a91f4b48b..31878e728bbad5bad7d14c3e2c1a5341f5a82174 100644
--- a/dgcnn/model_class.py
+++ b/dgcnn/model_class.py
@@ -5,7 +5,48 @@ import torch.nn.init as init
 
 # TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
 # TODO: There are more conv layers there
+class EdgeConvNew(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConvNew, self).__init__()
+        self.in_channels = in_channels
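+        # the shared 1x1 conv consumes concatenated (neighbour - centre, centre)
+        # edge features, hence the 2 * in_channels input channels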
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+
+    def forward(self, x, k=20):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=k)   # (batch_size, num_points, k)
+
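+        # offset the per-sample neighbour indices by batch element so they can
+        # index the flattened (batch_size * num_points, num_dims) tensor below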
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+
+        idx = idx + idx_base
+
+        idx = idx.view(-1)
+
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()   
+        feature = x.view(batch_size*num_points, -1)[idx, :]
+        feature = feature.view(batch_size, num_points, k, num_dims) 
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
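+        # edge feature for neighbour j of point i: concat (x_j - x_i, x_i), then
+        # permute to (batch_size, 2 * num_dims, num_points, k) for the conv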
+        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
 
+        feature = self.conv(feature)  # (batch_size, out_channels, num_points, k)
+
+        return feature
+
+    def knn(self, x, k):
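+        # x: (batch_size, num_dims, num_points); returns the indices of the k
+        # nearest neighbours of every point as (batch_size, num_points, k)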
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
 
 class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
@@ -91,8 +132,8 @@ class DgcnnClass(nn.Module):
     def __init__(self, num_classes):
         super(DgcnnClass, self).__init__()
         self.transform_net = Transform_Net()
-        self.edge_conv1 = EdgeConv(3, 64)
-        self.edge_conv2 = EdgeConv(64, 128)
+        self.edge_conv1 = EdgeConvNew(3, 64)
+        self.edge_conv2 = EdgeConvNew(64, 128)
         self.bn5 = nn.BatchNorm1d(256)
         self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
                                    self.bn5,
@@ -110,18 +151,25 @@ class DgcnnClass(nn.Module):
 
     def forward(self, x):
         # Apply Transform_Net on input point cloud
-        batch_size = x.size(0)
 
         trans_matrix = self.transform_net(x)
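+        # trans_matrix is a learned 3x3 alignment matrix per cloud; bmm applies it
+        # while keeping x in (batch_size, num_points, 3) layout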
         x = torch.bmm(x, trans_matrix)
+
+        batch_size = x.size(0)
+        num_points = x.size(1)
+        dim = x.size(2)
+
+        # x is (batch_size, num_points, dim); transpose to channels-first
+        # (batch_size, dim, num_points) for EdgeConvNew (a view would scramble the data)
+        x = x.transpose(2, 1).contiguous()
+
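+        # each EdgeConvNew returns (batch_size, out_channels, num_points, k); the
+        # max over the last dim aggregates the k neighbour features per point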
         x1 = self.edge_conv1(x)
         x1 = x1.max(dim=-1, keepdim=False)[0] 
         # print("x1 shape: ", x1.shape)
         x2 = self.edge_conv2(x1)
         x2 = x2.max(dim=-1, keepdim=False)[0]
         # print("x2 shape: ", x2.shape)
-        x5 = torch.cat((x1, x2), dim=2)  # (batch_size, 64+64+128+256, num_points)
-        x5 = x5.transpose(2, 1)                 # (batch_size, num_points, 64+64+128+256)
+        x5 = torch.cat((x1, x2), dim=1)  # (batch_size, 64+128 = 192, num_points)
+        # no transpose needed: conv5 is a Conv1d and already expects channels first
         # print("x5 shape: ", x5.shape)
         x_conv = self.conv5(x5)                     # (batch_size, 1024, num_points)   
         # print("x_conv shape: ", x_conv.shape)
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index d6f6b029c03281d7726ab15d6906864306ede860..32ed47a1cdf706c81b66b7ac41d786511fb96bb2 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -83,21 +83,26 @@ class ShapenetDataDgcnn(object):
         with open(json_file, 'r') as f:
             data = json.load(f)
 
+        print('First 10 entries in data: ', data[:10])
+
+        out_data = []
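+        # build absolute paths of the form <root>/raw/<category_id>/<sample_id>.txt
+        # without mutating the list loaded from the json split file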
         for i in range(len(data)):
-            data[i] = os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))
+            out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt')))
 
         # get one class of data
         # get the number of the class airplane
 
         if self.just_one_class:
-            data = [x for x in data if x.split('/')[-2] in [
+            out_data = [x for x in out_data if x.split('/')[-2] in [
                 self.cat['Airplane'],
                 self.cat['Lamp'],
                 self.cat['Chair'],
                 self.cat['Table'],
                 ]]
+
+        print('First 10 entries in out_data: ', out_data[:10])
 
-        return data
+        return out_data
     
     def get_seg_classes(self, cat):
         return self.seg_classes[cat]
@@ -206,7 +211,13 @@ class ShapenetDataDgcnn(object):
         labels = labels.astype(np.int64)
 
         # get the class name
-        class_name = self.train_file_list[index].split('/')[-2]
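+        # pick the file list that matches the active split so the class label
+        # comes from the sample that was actually loaded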
+        if self.split == 'train':
+            class_name = self.train_file_list[index].split('/')[-2]
+        elif self.split == 'test':
+            class_name = self.test_file_list[index].split('/')[-2]
+        elif self.split == 'val':
+            class_name = self.val_data_file[index].split('/')[-2]
+
         # apply the mapper
         class_name = self.class_mapper(class_name)