diff --git a/.gitignore b/.gitignore
index 854332645cdb38e7cf6a9adfae02c9cf4c74beaa..dcd6c13bb9d13b3edcfa447f9ce246d28bff16b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,4 +28,6 @@ lightning_logs
 cifar-10-batches-py
 *.ckpt
 *.gz
-MNIST
\ No newline at end of file
+MNIST
+results
+ModelNet10
\ No newline at end of file
diff --git a/dgcnn/dgcnn_main.py b/dgcnn/dgcnn_main.py
index 1b11c347989853803beb4f297cd01fa48c10ae85..5242c00d3418d79f8033329c4f8377f9b06989a6 100644
--- a/dgcnn/dgcnn_main.py
+++ b/dgcnn/dgcnn_main.py
@@ -16,7 +16,6 @@ def main():
     print(input_tensor.shape)
     out = dgcnn(input_tensor)
     print(out.shape)
-    print(out)
 
 if __name__ == '__main__':
     main()
diff --git a/dgcnn/edge_conv_layer_run.ipynb b/dgcnn/edge_conv_layer_run.ipynb
index e61e083164c09fe387fa05ef600cdf1d83d8f825..b27286d72afa4b2d9a6658ae9a3ee4d2b5fe0eb9 100644
--- a/dgcnn/edge_conv_layer_run.ipynb
+++ b/dgcnn/edge_conv_layer_run.ipynb
@@ -2,14 +2,14 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Output shape: torch.Size([1, 2048, 32, 10])\n"
+      "Output shape: torch.Size([1, 1024, 32, 10])\n"
      ]
     }
    ],
@@ -23,8 +23,9 @@
     "class EdgeConv(nn.Module):\n",
     "    def __init__(self, in_channels, out_channels):\n",
     "        super(EdgeConv, self).__init__()\n",
+    "        self.in_channels = in_channels\n",
     "        self.conv = nn.Sequential(\n",
-    "            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),\n",
+    "            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),\n",
     "            nn.BatchNorm2d(out_channels),\n",
     "            nn.ReLU()\n",
     "        )\n",
@@ -37,7 +38,7 @@
     "        knn_indices = self.knn(x, k)\n",
     "        knn_gathered = self.gather_neighbors(x, knn_indices)\n",
     "        edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)\n",
-    "        edge_features = edge_features.view(batch_size, feature_dim, 2*num_points, k)\n",
+    "        edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k)\n",
     "        return self.conv(edge_features).transpose(2, 1)\n",
     "\n",
     "    @staticmethod\n",
@@ -50,12 +51,11 @@
     "        _, indices = distances.topk(k=k, dim=-1, largest=False)\n",
     "        return indices\n",
     "\n",
-    "    @staticmethod\n",
-    "    def gather_neighbors(x, knn_indices):\n",
+    "    def gather_neighbors(self, x, knn_indices):\n",
     "        batch_size, num_points, _ = x.shape\n",
     "        _, _, k = knn_indices.shape\n",
     "        x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)\n",
-    "        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3))\n",
+    "        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, self.in_channels))\n",
     "        return neighbors\n",
     "    \n",
     "\n",
diff --git a/dgcnn/model.py b/dgcnn/model.py
index 9d3c3b4b1551f832ce4aa8abded78911dbb2ede5..3d96217d76dd3d4b48b2a744d3788fb17c45947b 100644
--- a/dgcnn/model.py
+++ b/dgcnn/model.py
@@ -11,20 +11,21 @@ class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(EdgeConv, self).__init__()
         self.in_channels = in_channels
-
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU() # TODO: replace with leaky relu
+            nn.ReLU()
         )
 
     def forward(self, x, k=20):
+        #batch_size, num_points, in_channels
+
         batch_size, num_points, feature_dim = x.shape
         x = x.view(batch_size, num_points,feature_dim )
         knn_indices = self.knn(x, k)
         knn_gathered = self.gather_neighbors(x, knn_indices)
         edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)
-        edge_features = edge_features.view(batch_size, feature_dim, 2*num_points, k)
+        edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k)
         return self.conv(edge_features).transpose(2, 1)
 
     @staticmethod
@@ -93,6 +94,12 @@ class DGCNN(nn.Module):
         self.edge_conv2 = EdgeConv(64, 128)
         self.edge_conv3 = EdgeConv(128, 256)
         self.edge_conv4 = EdgeConv(256, 512)
+        self.bn5 = nn.BatchNorm1d(1024)
+        self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
+                                   self.bn5,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.linear1 = nn.Linear(2048, 512, bias=False)
+
         self.fc = nn.Sequential(
             nn.Linear(512, 256),
             nn.BatchNorm1d(256),
@@ -107,17 +114,29 @@ class DGCNN(nn.Module):
 
     def forward(self, x):
         # Apply Transform_Net on input point cloud
+        batch_size = x.size(0)
 
         trans_matrix = self.transform_net(x)
         x = torch.bmm(x, trans_matrix)
         x1 = self.edge_conv1(x)
         x1 = x1.max(dim=-1, keepdim=False)[0] 
+        # print("x1 shape: ", x1.shape)
         x2 = self.edge_conv2(x1)
         x2 = x2.max(dim=-1, keepdim=False)[0]
+        # print("x2 shape: ", x2.shape)
         x3 = self.edge_conv3(x2)
         x3 = x3.max(dim=-1, keepdim=False)[0]
+        # print("x3 shape: ", x3.shape)
         x4 = self.edge_conv4(x3)
         x4 = x4.max(dim=-1, keepdim=False)[0]
-        x = torch.max(x4, dim=1, keepdim=True)[0]
-        x = self.fc(x.squeeze(1))
-        return x
\ No newline at end of file
+        # print("x4 shape: ", x4.shape)
+        # x5 = torch.cat((x1, x2, x3, x4), dim=1)  # (batch_size, 64+64+128+256, num_points)
+        # x6 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # x7 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # x8 = torch.cat((x6, x7), 1)              # (batch_size, emb_dims*2)
+
+        # x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
+
+        x9 = torch.max(x4, dim=1, keepdim=True)[0]
+        x10 = self.fc(x9.squeeze(1))
+        return x10
\ No newline at end of file
diff --git a/point_net_autoencoder/shapenet_data.py b/point_net_autoencoder/shapenet_data.py
index 7cd999b81b4c4b07bb4f90af9feb410dc77de56a..7e5a3fe22a0251d04c3ba78f2ff8bae5629979ac 100644
--- a/point_net_autoencoder/shapenet_data.py
+++ b/point_net_autoencoder/shapenet_data.py
@@ -16,7 +16,7 @@ class ShapenetData(object):
                  small_data=False,
                  small_data_size=10,
                  just_one_class=False,
-                 norm=True,
+                 norm=False,
                  data_augmentation=False
                  ) -> None:
         
@@ -115,6 +115,9 @@ class ShapenetData(object):
         elif self.split == 'val':
             point_set = np.loadtxt(self.val_data_file[index]).astype(np.float32)
 
+        # get labels
+        labels = point_set[:, 6]
+
         # get just the points
         point_set = point_set[:, 0:3]
 
@@ -122,14 +125,18 @@ class ShapenetData(object):
         if self.norm:
             point_set = self.normalize(point_set)
 
-        # choice = np.random.choice(len(point_set), self.npoints, replace=True)
+        choice = np.random.choice(len(point_set), self.npoints, replace=True)
         # chose the first npoints
         choice = np.arange(self.npoints)
 
         point_set = point_set[choice, :]
         point_set = point_set.astype(np.float32)
 
-        return point_set
+        # get the labels
+        labels = labels[choice]
+        labels = labels.astype(np.int64)
+
+        return point_set, labels
 
     def __len__(self):
         if self.split == 'train':
@@ -142,7 +149,10 @@ class ShapenetData(object):
 
 if __name__ == "__main__":
 
-  shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
+  shapenet_data = ShapenetData(
+      root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+      split='train'
+      )
 #   print(shapenet_data.train_file_list)
   print(shapenet_data.get_seg_classes('Car'))
   print(shapenet_data.get_class_names())