diff --git a/dgcnn/__init__.py b/dgcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dgcnn/dgcnn_main.py b/dgcnn/dgcnn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7173890db46f452dcd53600dfc9b5b530d134baf
--- /dev/null
+++ b/dgcnn/dgcnn_main.py
@@ -0,0 +1,21 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+from model import DGCNN
+
+
+
+def main():
+    num_classes = 10
+    dgcnn = DGCNN(num_classes)
+    dgcnn.eval()
+    # simple test
+    input_tensor = torch.randn(1, 128, 3)
+    print(input_tensor.shape)
+    out = dgcnn(input_tensor)
+    print(out.shape)
+
+if __name__ == '__main__':
+    main()
diff --git a/dgcnn/edge_conv_layer_run.ipynb b/dgcnn/edge_conv_layer_run.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e61e083164c09fe387fa05ef600cdf1d83d8f825
--- /dev/null
+++ b/dgcnn/edge_conv_layer_run.ipynb
@@ -0,0 +1,234 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Output shape: torch.Size([1, 2048, 32, 10])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.nn.init as init\n",
+    "\n",
+    "\n",
+    "class EdgeConv(nn.Module):\n",
+    "    def __init__(self, in_channels, out_channels):\n",
+    "        super(EdgeConv, self).__init__()\n",
+    "        self.conv = nn.Sequential(\n",
+    "            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),\n",
+    "            nn.BatchNorm2d(out_channels),\n",
+    "            nn.ReLU()\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x, k=20):\n",
+    "        #batch_size, num_points, in_channels\n",
+    "\n",
+    "        batch_size, num_points, feature_dim = x.shape\n",
+    "        x = x.view(batch_size, num_points,feature_dim )\n",
+    "        knn_indices = self.knn(x, k)\n",
+    "        knn_gathered = self.gather_neighbors(x, knn_indices)\n",
+    "        edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)\n",
+    "        edge_features = edge_features.view(batch_size, feature_dim, 2*num_points, k)\n",
+    "        return self.conv(edge_features).transpose(2, 1)\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def knn(x, k):\n",
+    "        \"\"\"Find the indices of the k nearest neighbors for each point in the input tensor.\"\"\"\n",
+    "        batch_size, num_points, _ = x.shape\n",
+    "        x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)\n",
+    "        x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)\n",
+    "        distances = torch.norm(x_expanded - x_transposed, dim=-1)\n",
+    "        _, indices = distances.topk(k=k, dim=-1, largest=False)\n",
+    "        return indices\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def gather_neighbors(x, knn_indices):\n",
+    "        batch_size, num_points, _ = x.shape\n",
+    "        _, _, k = knn_indices.shape\n",
+    "        x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)\n",
+    "        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3))\n",
+    "        return neighbors\n",
+    "    \n",
+    "\n",
+    "# Generate a sample input vector\n",
+    "batch_size = 1\n",
+    "in_channels = 3\n",
+    "num_points = 1024\n",
+    "\n",
+    "point_cloud = torch.randn(batch_size, in_channels, num_points)\n",
+    "\n",
+    "# Create an EdgeConv layer\n",
+    "out_channels = 32  # You can choose the desired number of output channels\n",
+    "edge_conv_layer = EdgeConv(in_channels, out_channels)\n",
+    "edge_conv_layer.eval()\n",
+    "\n",
+    "# Pass the sample input vector through the EdgeConv layer\n",
+    "k = 10\n",
+    "output = edge_conv_layer(point_cloud.view(batch_size, num_points, in_channels), k)\n",
+    "print(\"Output shape:\", output.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def knn(x, k):\n",
+    "    inner = -2 * torch.matmul(x.transpose(2, 1), x)\n",
+    "    xx = torch.sum(x ** 2, dim=1, keepdim=True)\n",
+    "    pairwise_distance = -xx - inner - xx.transpose(2, 1)\n",
+    "    knn_indices = pairwise_distance.topk(k=k, dim=-1)[1]\n",
+    "    return knn_indices\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "# Set sample input parameters\n",
+    "batch_size = 2\n",
+    "num_points = 1024\n",
+    "feature_dim = 3\n",
+    "k = 20\n",
+    "\n",
+    "# Create a random input tensor\n",
+    "x = torch.randn(batch_size, num_points, feature_dim)\n",
+    "\n",
+    "print(\"Sample input vector x:\")\n",
+    "print(x.shape)\n",
+    "\n",
+    "# Compute the k nearest neighbors\n",
+    "knn_indices = knn(x, k)\n",
+    "\n",
+    "print(\"k nearest neighbors indices:\")   \n",
+    "print(knn_indices.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Sample input vector x:\n",
+      "torch.Size([1, 1024, 3])\n",
+      "k nearest neighbors indices:\n",
+      "torch.Size([1, 1024, 20])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "def knn(x, k):\n",
+    "    \"\"\"Find the indices of the k nearest neighbors for each point in the input tensor.\"\"\"\n",
+    "    batch_size, num_points, _ = x.shape\n",
+    "    x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)\n",
+    "    x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)\n",
+    "    distances = torch.norm(x_expanded - x_transposed, dim=-1)\n",
+    "    _, indices = distances.topk(k=k, dim=-1, largest=False)\n",
+    "    return indices\n",
+    "\n",
+    "# Set sample input parameters\n",
+    "batch_size = 1\n",
+    "num_points = 1024\n",
+    "feature_dim = 3\n",
+    "k = 20\n",
+    "\n",
+    "# Create a random input tensor\n",
+    "x = torch.randn(batch_size, num_points, feature_dim)\n",
+    "\n",
+    "print(\"Sample input vector x:\")\n",
+    "print(x.shape)\n",
+    "\n",
+    "# Compute the k nearest neighbors\n",
+    "knn_indices = knn(x, k)\n",
+    "\n",
+    "print(\"k nearest neighbors indices:\")   \n",
+    "print(knn_indices.shape)\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "x_expanded: torch.Size([1, 1024, 1024, 3])\n",
+      "tensor([-0.7148, -0.1634, -1.3527])\n",
+      "tensor([-0.7148, -0.1634, -1.3527])\n"
+     ]
+    }
+   ],
+   "source": [
+    "x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)\n",
+    "\n",
+    "print(\"x_expanded:\", x_expanded.shape)\n",
+    "\n",
+    "print(x_expanded[0, 0, 0, :])\n",
+    "print(x[0, 0, :])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "knn_indices: torch.Size([1, 1024, 20])\n",
+      "neighbors: torch.Size([1, 1024, 20, 3])\n"
+     ]
+    }
+   ],
+   "source": [
+    "print('knn_indices:', knn_indices.shape)\n",
+    "\n",
+    "# Gather the neighbors\n",
+    "neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3))\n",
+    "\n",
+    "print(\"neighbors:\", neighbors.shape)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/model.py b/dgcnn/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..366dd16da831aa55401ee26764450791ffdcebb3
--- /dev/null
+++ b/dgcnn/model.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+#TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
+
+class EdgeConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConv, self).__init__()
+        self.in_channels = in_channels
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        )
+
+    def forward(self, x, k=20):
+        batch_size, num_points, feature_dim = x.shape
+        x = x.view(batch_size, num_points,feature_dim )
+        knn_indices = self.knn(x, k)
+        knn_gathered = self.gather_neighbors(x, knn_indices)
+        edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)
+        edge_features = edge_features.view(batch_size, feature_dim, 2*num_points, k)
+        return self.conv(edge_features).transpose(2, 1)
+
+    @staticmethod
+    def knn(x, k):
+        """Find the indices of the k nearest neighbors for each point in the input tensor."""
+        batch_size, num_points, _ = x.shape
+        x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)
+        x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)
+        distances = torch.norm(x_expanded - x_transposed, dim=-1)
+        _, indices = distances.topk(k=k, dim=-1, largest=False)
+        return indices
+    
+    def gather_neighbors(self, x, knn_indices):
+        batch_size, num_points, _ = x.shape
+        _, _, k = knn_indices.shape
+        x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)
+        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, self.in_channels))
+        return neighbors
+    
+    
+class Transform_Net(nn.Module):
+    def __init__(self, k=3):
+        super(Transform_Net, self).__init__()
+        self.k = k
+
+        self.conv1 = nn.Conv1d(k, 64, 1)
+        self.conv2 = nn.Conv1d(64, 128, 1)
+        self.conv3 = nn.Conv1d(128, 1024, 1)
+
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k*k)
+
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(1024)
+        self.bn4 = nn.BatchNorm1d(512)
+        self.bn5 = nn.BatchNorm1d(256)
+
+    def forward(self, x):
+        # Input shape: (batch_size, k, num_points)
+        x = x.transpose(2, 1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+
+        x = F.relu(self.bn4(self.fc1(x)))
+        x = F.relu(self.bn5(self.fc2(x)))
+        x = self.fc3(x)
+
+        identity = torch.eye(self.k, dtype=x.dtype, device=x.device)
+        transform = x.view(-1, self.k, self.k) + identity
+
+        transform = transform.transpose(2, 1)
+
+        return transform
+
+class DGCNN(nn.Module):
+    def __init__(self, num_classes):
+        super(DGCNN, self).__init__()
+        self.transform_net = Transform_Net()
+        self.edge_conv1 = EdgeConv(3, 64)
+        self.edge_conv2 = EdgeConv(64, 128)
+        self.edge_conv3 = EdgeConv(128, 256)
+        self.edge_conv4 = EdgeConv(256, 512)
+        self.fc = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(),
+            nn.Dropout(p=0.5),
+            nn.Linear(256, 128),
+            nn.BatchNorm1d(128),
+            nn.ReLU(),
+            nn.Dropout(p=0.5),
+            nn.Linear(128, num_classes),
+        )
+
+    def forward(self, x):
+        # Apply Transform_Net on input point cloud
+
+        trans_matrix = self.transform_net(x)
+        x = torch.bmm(x, trans_matrix)
+        x1 = self.edge_conv1(x)
+        x1 = x1.max(dim=-1, keepdim=False)[0] 
+        x2 = self.edge_conv2(x1)
+        x2 = x2.max(dim=-1, keepdim=False)[0]
+        x3 = self.edge_conv3(x2)
+        x3 = x3.max(dim=-1, keepdim=False)[0]
+        x4 = self.edge_conv4(x3)
+        x4 = x4.max(dim=-1, keepdim=False)[0]
+        x = torch.max(x4, dim=1, keepdim=True)[0]
+        x = self.fc(x.squeeze(1))
+        return x
\ No newline at end of file
diff --git a/dgcnn/shapenet_data.py b/dgcnn/shapenet_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd999b81b4c4b07bb4f90af9feb410dc77de56a
--- /dev/null
+++ b/dgcnn/shapenet_data.py
@@ -0,0 +1,154 @@
+
+import json
+import os
+
+import numpy as np
+
+
+class ShapenetData(object):
+    """
+    The is the data loader for the ShapeNet dataset. Only for data segmentation, not for classification.
+    """
+    def __init__(self,
+                 root,
+                 npoints=1200,
+                 split='train',
+                 small_data=False,
+                 small_data_size=10,
+                 just_one_class=False,
+                 norm=True,
+                 data_augmentation=False
+                 ) -> None:
+        
+        self.npoints = npoints
+        self.root = root
+        self.split = split
+        self.small_data = small_data
+        self.small_data_size = small_data_size
+        self.just_one_class = just_one_class
+        self.norm = norm
+        self.data_augmentation = data_augmentation
+
+        # data operations
+        self.train_data_suffled_list_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_train_file_list.json')
+        self.test_data_suffled_list_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_test_file_list.json')
+        self.val_data_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_val_file_list.json')
+        
+
+        self.catfile = os.path.join(self.root, 
+                                    'raw',
+                                    'synsetoffset2category.txt')
+        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
+                        'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
+                        'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
+                        'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
+                        'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
+        
+        self.cat = {}
+        with open(self.catfile, 'r') as f:
+            for line in f:
+                ls = line.strip().split()
+                self.cat[ls[0]] = ls[1]
+
+        self.train_file_list = self.get_list_shuffled_data(self.root, self.train_data_suffled_list_file)
+        self.test_file_list = self.get_list_shuffled_data(self.root, self.test_data_suffled_list_file)
+        self.val_data_file = self.get_list_shuffled_data(self.root, self.val_data_file)
+
+        # take the first 10 data for training
+        if self.small_data:
+            self.train_file_list = self.train_file_list[:self.small_data_size]
+            self.test_file_list = self.test_file_list[:self.small_data_size]
+            self.val_data_file = self.val_data_file[:self.small_data_size]
+        
+
+    def get_list_shuffled_data(self,root, json_file):
+       # read the json file and return the list of data
+        with open(json_file, 'r') as f:
+            data = json.load(f)
+
+        for i in range(len(data)):
+            data[i] = os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))
+
+        # get one class of data
+        # get the the number of the class airplane
+
+        if self.just_one_class:
+            data = [x for x in data if x.split('/')[-2] == self.cat['Airplane']]
+
+        return data
+    
+    def get_seg_classes(self, cat):
+        return self.seg_classes[cat]
+    
+    def get_class_names(self):
+        return list(self.cat.values())
+    
+    def get_all_names_of_classes(self):
+        return list(self.cat.keys())
+    
+    def normalize(self, pc):
+        centroid = np.mean(pc, axis=0)
+        pc = pc - centroid
+        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+        pc = pc / m
+        return pc
+        
+    # TODO: add the selection for a given class
+
+    def __getitem__(self, index):
+        if self.split == 'train':
+            point_set = np.loadtxt(self.train_file_list[index]).astype(np.float32)
+        elif self.split == 'test':
+            point_set = np.loadtxt(self.test_file_list[index]).astype(np.float32)
+        elif self.split == 'val':
+            point_set = np.loadtxt(self.val_data_file[index]).astype(np.float32)
+
+        # get just the points
+        point_set = point_set[:, 0:3]
+
+        # normalize the points
+        if self.norm:
+            point_set = self.normalize(point_set)
+
+        # choice = np.random.choice(len(point_set), self.npoints, replace=True)
+        # chose the first npoints
+        choice = np.arange(self.npoints)
+
+        point_set = point_set[choice, :]
+        point_set = point_set.astype(np.float32)
+
+        return point_set
+
+    def __len__(self):
+        if self.split == 'train':
+            return len(self.train_file_list)
+        elif self.split == 'test':
+            return len(self.test_file_list)
+        elif self.split == 'val':
+            return len(self.val_data_file)
+
+
+if __name__ == "__main__":
+
+  shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
+#   print(shapenet_data.train_file_list)
+  print(shapenet_data.get_seg_classes('Car'))
+  print(shapenet_data.get_class_names())
+  print(shapenet_data.get_all_names_of_classes())
+
+  # get the first point cloud
+  point_cloud = shapenet_data[0]
+
+  print(point_cloud[:10])
\ No newline at end of file
diff --git a/dgcnn/transform_net_run.ipynb b/dgcnn/transform_net_run.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..fa3166bd79431ad1af980c707a58c6026be8c030
--- /dev/null
+++ b/dgcnn/transform_net_run.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Transform matrix shape:  torch.Size([1, 3, 3])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.nn.init as init\n",
+    "\n",
+    "class Transform_Net(nn.Module):\n",
+    "    def __init__(self, k=3):\n",
+    "        super(Transform_Net, self).__init__()\n",
+    "        self.k = k\n",
+    "\n",
+    "        self.conv1 = nn.Conv1d(k, 64, 1)\n",
+    "        self.conv2 = nn.Conv1d(64, 128, 1)\n",
+    "        self.conv3 = nn.Conv1d(128, 1024, 1)\n",
+    "\n",
+    "        self.fc1 = nn.Linear(1024, 512)\n",
+    "        self.fc2 = nn.Linear(512, 256)\n",
+    "        self.fc3 = nn.Linear(256, k*k)\n",
+    "\n",
+    "        self.bn1 = nn.BatchNorm1d(64)\n",
+    "        self.bn2 = nn.BatchNorm1d(128)\n",
+    "        self.bn3 = nn.BatchNorm1d(1024)\n",
+    "        self.bn4 = nn.BatchNorm1d(512)\n",
+    "        self.bn5 = nn.BatchNorm1d(256)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # Input shape: (batch_size, k, num_points)\n",
+    "        x = F.relu(self.bn1(self.conv1(x)))\n",
+    "        x = F.relu(self.bn2(self.conv2(x)))\n",
+    "        x = F.relu(self.bn3(self.conv3(x)))\n",
+    "\n",
+    "        x = torch.max(x, 2, keepdim=True)[0]\n",
+    "        x = x.view(-1, 1024)\n",
+    "\n",
+    "        x = F.relu(self.bn4(self.fc1(x)))\n",
+    "        x = F.relu(self.bn5(self.fc2(x)))\n",
+    "        x = self.fc3(x)\n",
+    "\n",
+    "        identity = torch.eye(self.k, dtype=x.dtype, device=x.device)\n",
+    "        transform = x.view(-1, self.k, self.k) + identity\n",
+    "\n",
+    "        return transform\n",
+    "    \n",
+    "\n",
+    "\n",
+    "# Create a random input tensor\n",
+    "batch_size = 1\n",
+    "num_points = 1024\n",
+    "k = 3\n",
+    "\n",
+    "# Generate a random point cloud of shape (batch_size, k, num_points)\n",
+    "input_tensor = torch.randn(batch_size, k, num_points)\n",
+    "\n",
+    "\n",
+    "# Initialize the Transform_Net module\n",
+    "transform_net = Transform_Net(k)\n",
+    "\n",
+    "transform_net.eval()\n",
+    "\n",
+    "# Forward pass the input tensor through the Transform_Net module\n",
+    "transform_matrix = transform_net(input_tensor)\n",
+    "\n",
+    "# Check the output tensor shape\n",
+    "print(\"Transform matrix shape: \", transform_matrix.shape)\n",
+    "\n",
+    "# Apply the transformation matrix to the input point cloud\n",
+    "input_tensor_transformed = torch.bmm(transform_matrix, input_tensor)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}