From 81877145e1468ab4585687b15acf34f2a2c2b70d Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 20 Apr 2023 13:37:23 +0200
Subject: [PATCH] update

---
 cifar_example/cifar10_lightning.py            | 113 +++++
 cifar_example/cifar10_lightning_ver_2.py      |  92 ++--
 dgcnn/attention_usage.ipynb                   |  78 ++++
 dgcnn/get_size_of_dataset.py                  |   6 +-
 .../{ => jupyters}/edge_conv_layer_run.ipynb  |  25 ++
 dgcnn/jupyters/model10_vis.ipynb              | 304 +++++++++++++
 dgcnn/jupyters/my_shapenet_vis.ipynb          | 414 ++++++++++++++++++
 dgcnn/{ => jupyters}/transform_net_run.ipynb  |  50 +++
 dgcnn/my_models/model_shape_net.py            |   4 +-
 dgcnn/shapenet_data_dgcnn.py                  |  42 +-
 10 files changed, 1064 insertions(+), 64 deletions(-)
 create mode 100644 cifar_example/cifar10_lightning.py
 create mode 100644 dgcnn/attention_usage.ipynb
 rename dgcnn/{ => jupyters}/edge_conv_layer_run.ipynb (93%)
 create mode 100644 dgcnn/jupyters/model10_vis.ipynb
 create mode 100644 dgcnn/jupyters/my_shapenet_vis.ipynb
 rename dgcnn/{ => jupyters}/transform_net_run.ipynb (65%)

diff --git a/cifar_example/cifar10_lightning.py b/cifar_example/cifar10_lightning.py
new file mode 100644
index 0000000..bec2a49
--- /dev/null
+++ b/cifar_example/cifar10_lightning.py
@@ -0,0 +1,113 @@
+import os
+from torch import nn
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
+from pytorch_lightning import LightningModule, Trainer
+from torchvision.datasets import CIFAR10
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
+from pytorch_lightning.loggers import CSVLogger
+
+# import modules
+from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer
+from cifar_example.cifar_transformer_modules.embedding import Embedding
+
+# variables
+PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+BATCH_SIZE = 256 if torch.cuda.is_available() else 64
+
+class CIFAR10LightningTransformer(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.embedding_size=64
+        self.criterion = torch.nn.CrossEntropyLoss()
+        self.embedding = Embedding(
+            patch_size=8, 
+            in_channels=3, 
+            out_channels=self.embedding_size, 
+            return_patches=False, 
+            extra_token=True
+            )
+        self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3)
+        self.fc = nn.Linear(self.embedding_size, 10)
+    
+    def forward(self, x):
+        embedding = self.embedding(x)
+        context = self.self_attention(embedding)
+        # get the first token
+        context = context[:, 0, :]
+
+        # context = context.mean(dim=1)
+
+        # get the classification
+        context = self.fc(context)
+    
+        return context
+    
+    def accuracy(self, logits, y):
+        preds = torch.argmax(logits, dim=1)
+        return torch.sum(preds == y).item() / len(y)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.criterion(logits, y)
+        self.log('train_loss', loss)
+        self.log('train_acc', self.accuracy(logits, y))
+        return loss
+    
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.criterion(logits, y)
+        self.log('val_loss', loss)
+        self.log('val_acc', self.accuracy(logits, y))
+        return loss
+    
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        logits = self(x)
+        loss = self.criterion(logits, y)
+        self.log('test_loss', loss)
+        self.log('test_acc', self.accuracy(logits, y))
+        return loss
+    
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+cifar_model = CIFAR10LightningTransformer()
+
+# train_ds = CIFAR10(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
+# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
+# val_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
+# val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
+# test_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
+# test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
+
+# get the train data
+train_ds = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
+# get test data
+test_ds = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
+
+# get the train loader
+train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
+
+# get the test loader
+test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)
+
+# Initialize a trainer
+trainer = Trainer(
+    accelerator="auto",
+    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
+    max_epochs=10,
+    callbacks=[TQDMProgressBar(refresh_rate=20)],
+    logger=CSVLogger(save_dir="logs/"),
+)
+
+# Train the model ⚡
+# trainer.fit(cifar_model, train_loader)
+
+# Test the model
+trainer.test(cifar_model, dataloaders=test_loader)
\ No newline at end of file
diff --git a/cifar_example/cifar10_lightning_ver_2.py b/cifar_example/cifar10_lightning_ver_2.py
index c12ca99..c23b1d4 100644
--- a/cifar_example/cifar10_lightning_ver_2.py
+++ b/cifar_example/cifar10_lightning_ver_2.py
@@ -54,57 +54,57 @@ from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTrans
 from cifar_example.cifar_transformer_modules.embedding import Embedding
 
 # create  resnet model for cifar10 classification
-class CIFAR10Model(pl.LightningModule):
-    def __init__(self):
-        super().__init__()
-        self.model = self.create_model()
+# class CIFAR10Model(pl.LightningModule):
+#     def __init__(self):
+#         super().__init__()
+#         self.model = self.create_model()
 
-    def forward(self, x):
-        return self.model(x)
+#     def forward(self, x):
+#         return self.model(x)
     
-    def create_model(self):
-        model = resnet18(pretrained=False, num_classes=10)
-        model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
-        model.maxpool = nn.Identity()
-        return model
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        # log accuracy
-        self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        # log accuracy
-        self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
-
-    def test_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        # log accuracy
-        self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        return loss
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=0.001)
-
-    def accuracy(self, y_hat, y):
-        preds = torch.argmax(y_hat, dim=1)
-        return (preds == y).float().mean()
+#     def create_model(self):
+#         model = resnet18(pretrained=False, num_classes=10)
+#         model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+#         model.maxpool = nn.Identity()
+#         return model
+
+#     def training_step(self, batch, batch_idx):
+#         x, y = batch
+#         y_hat = self(x)
+#         loss = F.cross_entropy(y_hat, y)
+#         self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         # log accuracy
+#         self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         return loss
+
+#     def validation_step(self, batch, batch_idx):
+#         x, y = batch
+#         y_hat = self(x)
+#         loss = F.cross_entropy(y_hat, y)
+#         self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         # log accuracy
+#         self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         return loss
+
+#     def test_step(self, batch, batch_idx):
+#         x, y = batch
+#         y_hat = self(x)
+#         loss = F.cross_entropy(y_hat, y)
+#         self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         # log accuracy
+#         self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+#         return loss
+
+#     def configure_optimizers(self):
+#         return torch.optim.Adam(self.parameters(), lr=0.001)
+
+#     def accuracy(self, y_hat, y):
+#         preds = torch.argmax(y_hat, dim=1)
+#         return (preds == y).float().mean()
     
 
 # create  resnet model for cifar10 classification
-# class CIFAR10Model(pl.LightningModule):
+class CIFAR10Model(pl.LightningModule):
     def __init__(self):
         super().__init__()
         self.embedding_size=64
diff --git a/dgcnn/attention_usage.ipynb b/dgcnn/attention_usage.ipynb
new file mode 100644
index 0000000..4c7e038
--- /dev/null
+++ b/dgcnn/attention_usage.ipynb
@@ -0,0 +1,78 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "AssertionError",
+     "evalue": "embed_dim must be divisible by num_heads",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
+      "\u001b[1;32m/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb Cell 1\u001b[0m in \u001b[0;36m2\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=24'>25</a>\u001b[0m \u001b[39m# show how to use self attention\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=25'>26</a>\u001b[0m x \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m2\u001b[39m, \u001b[39m1024\u001b[39m, \u001b[39m3\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=26'>27</a>\u001b[0m self_attention \u001b[39m=\u001b[39m SelfAttention(\u001b[39m3\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m0.1\u001b[39;49m)\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=27'>28</a>\u001b[0m out \u001b[39m=\u001b[39m self_attention(x)\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=28'>29</a>\u001b[0m \u001b[39mprint\u001b[39m(out\u001b[39m.\u001b[39mshape)\n",
+      "\u001b[1;32m/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb Cell 1\u001b[0m in \u001b[0;36m1\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_heads \u001b[39m=\u001b[39m num_heads\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout \u001b[39m=\u001b[39m dropout\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mself_attention \u001b[39m=\u001b[39m MultiheadAttention(in_channels, num_heads\u001b[39m=\u001b[39;49mnum_heads, dropout\u001b[39m=\u001b[39;49mdropout)\n",
+      "File \u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py:960\u001b[0m, in \u001b[0;36mMultiheadAttention.__init__\u001b[0;34m(self, embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)\u001b[0m\n\u001b[1;32m    958\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatch_first \u001b[39m=\u001b[39m batch_first\n\u001b[1;32m    959\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhead_dim \u001b[39m=\u001b[39m embed_dim \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_heads\n\u001b[0;32m--> 960\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhead_dim \u001b[39m*\u001b[39m num_heads \u001b[39m==\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membed_dim, \u001b[39m\"\u001b[39m\u001b[39membed_dim must be divisible by num_heads\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    962\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_qkv_same_embed_dim:\n\u001b[1;32m    963\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mq_proj_weight \u001b[39m=\u001b[39m Parameter(torch\u001b[39m.\u001b[39mempty((embed_dim, embed_dim), \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfactory_kwargs))\n",
+      "\u001b[0;31mAssertionError\u001b[0m: embed_dim must be divisible by num_heads"
+     ]
+    }
+   ],
+   "source": [
+    "# import pytorch attention\n",
+    "from torch.nn import MultiheadAttention\n",
+    "from torch import nn\n",
+    "import torch\n",
+    "\n",
+    "# implement self attention\n",
+    "class SelfAttention(nn.Module):\n",
+    "    def __init__(self, in_channels, num_heads, dropout):\n",
+    "        super(SelfAttention, self).__init__()\n",
+    "        self.in_channels = in_channels\n",
+    "        self.num_heads = num_heads\n",
+    "        self.dropout = dropout\n",
+    "        self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size = x.size(0)\n",
+    "        num_points = x.size(2)\n",
+    "        x = x.view(batch_size, -1, num_points)\n",
+    "        x = x.permute(1, 0, 2)\n",
+    "        out, attn = self.self_attention(x, x, x)\n",
+    "        out = out.permute(1, 0, 2)\n",
+    "        out = out.view(batch_size, -1, num_points)\n",
+    "        return out\n",
+    "\n",
+    "# show how to use self attention\n",
+    "x = torch.randn(2, 1024, 3)\n",
+    "self_attention = SelfAttention(3, 3, 0.1)\n",
+    "out = self_attention(x)\n",
+    "print(out.shape)\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py
index af7cf04..1ee6847 100644
--- a/dgcnn/get_size_of_dataset.py
+++ b/dgcnn/get_size_of_dataset.py
@@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn(
       return_cls_label=True,
       small_data=False,
       small_data_size=1000,
-      just_four_classes=True,
+      num_classes=True,
       split='train',
       norm=True
       )
@@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_four_classes=True,
+        num_classes=True,
         split='test',
         norm=True
         )
@@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn(
         return_cls_label=True,
         small_data=False,
         small_data_size=1000,
-        just_four_classes=True,
+        num_classes=True,
         split='val',
         norm=True
         )
diff --git a/dgcnn/edge_conv_layer_run.ipynb b/dgcnn/jupyters/edge_conv_layer_run.ipynb
similarity index 93%
rename from dgcnn/edge_conv_layer_run.ipynb
rename to dgcnn/jupyters/edge_conv_layer_run.ipynb
index b27286d..6792c92 100644
--- a/dgcnn/edge_conv_layer_run.ipynb
+++ b/dgcnn/jupyters/edge_conv_layer_run.ipynb
@@ -207,6 +207,31 @@
     "\n",
     "print(\"neighbors:\", neighbors.shape)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([32, 3])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# Assuming you have a tensor 'x' with shape (8, 32, 3)\n",
+    "x = torch.randn(8, 32, 3)\n",
+    "\n",
+    "# Remove the batch dimension\n",
+    "x_no_batch = x[1,:,:]\n",
+    "\n",
+    "print(x_no_batch.shape)"
+   ]
   }
  ],
  "metadata": {
diff --git a/dgcnn/jupyters/model10_vis.ipynb b/dgcnn/jupyters/model10_vis.ipynb
new file mode 100644
index 0000000..5d5b91c
--- /dev/null
+++ b/dgcnn/jupyters/model10_vis.ipynb
@@ -0,0 +1,304 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1.13.1+cu117\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "os.environ['TORCH'] = torch.__version__\n",
+    "print(torch.__version__)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from glob import glob\n",
+    "from PIL import Image\n",
+    "from tqdm.auto import tqdm\n",
+    "\n",
+    "import wandb\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "import numpy as np\n",
+    "import networkx as nx\n",
+    "import matplotlib.pyplot as plt\n",
+    "from pyvis.network import Network\n",
+    "from mpl_toolkits.mplot3d import Axes3D\n",
+    "\n",
+    "import torch_geometric.transforms as T\n",
+    "from torch_geometric.datasets import ModelNet\n",
+    "from torch_geometric.loader import DataLoader\n",
+    "from torch_geometric.utils import to_networkx\n",
+    "from torch_geometric.nn import knn_graph, radius_graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "wandb version 0.14.2 is available!  To upgrade, please run:\n",
+       " $ pip install wandb --upgrade"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Tracking run with wandb version 0.13.10"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Run data is saved locally in <code>/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/jupyters/wandb/run-20230414_091510-h2iclcgp</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Syncing run <strong><a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">modelnet10/train/sampling-comparison</a></strong> to <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View project at <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run at <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "wandb_project = \"pyg-point-cloud\" #@param {\"type\": \"string\"}\n",
+    "wandb_run_name = \"modelnet10/train/sampling-comparison\" #@param {\"type\": \"string\"}\n",
+    "\n",
+    "wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n",
+    "\n",
+    "# Set experiment configs to be synced with wandb\n",
+    "config = wandb.config\n",
+    "config.display_sample = 2048  #@param {type:\"slider\", min:256, max:4096, step:16}\n",
+    "config.modelnet_dataset_alias = \"ModelNet10\" #@param [\"ModelNet10\", \"ModelNet40\"] {type:\"raw\"}\n",
+    "\n",
+    "# Classes for ModelNet10 and ModelNet40\n",
+    "categories = sorted([\n",
+    "    x.split(os.sep)[-2]\n",
+    "    for x in glob(os.path.join(\n",
+    "        config.modelnet_dataset_alias, \"raw\", '*', ''\n",
+    "    ))\n",
+    "])\n",
+    "\n",
+    "\n",
+    "config.categories = categories"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pre_transform = T.NormalizeScale()\n",
+    "transform = T.SamplePoints(config.display_sample)\n",
+    "train_dataset = ModelNet(\n",
+    "    root=config.modelnet_dataset_alias,\n",
+    "    name=config.modelnet_dataset_alias[-2:],\n",
+    "    train=True,\n",
+    "    transform=transform,\n",
+    "    pre_transform=pre_transform\n",
+    ")\n",
+    "val_dataset = ModelNet(\n",
+    "    root=config.modelnet_dataset_alias,\n",
+    "    name=config.modelnet_dataset_alias[-2:],\n",
+    "    train=False,\n",
+    "    transform=transform,\n",
+    "    pre_transform=pre_transform\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 20/20 [00:01<00:00, 11.47it/s]\n",
+      "100%|██████████| 100/100 [00:02<00:00, 36.10it/s]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run <strong style=\"color:#cdcd00\">modelnet10/train/sampling-comparison</strong> at: <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp</a><br/>Synced 5 W&B file(s), 3 media file(s), 103 artifact file(s) and 0 other file(s)"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find logs at: <code>./wandb/run-20230414_091510-h2iclcgp/logs</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "table = wandb.Table(columns=[\"Model\", \"Class\", \"Split\"])\n",
+    "category_dict = {key: 0 for key in config.categories}\n",
+    "for idx in tqdm(range(len(train_dataset[:20]))):\n",
+    "    point_cloud = wandb.Object3D(train_dataset[idx].pos.numpy())\n",
+    "    category = config.categories[int(train_dataset[idx].y.item())]\n",
+    "    category_dict[category] += 1\n",
+    "    table.add_data(\n",
+    "        point_cloud,\n",
+    "        category,\n",
+    "        \"Train\"\n",
+    "    )\n",
+    "\n",
+    "data = [[key, category_dict[key]] for key in config.categories]\n",
+    "wandb.log({\n",
+    "    f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\" : wandb.plot.bar(\n",
+    "        wandb.Table(data=data, columns = [\"Class\", \"Frequency\"]),\n",
+    "        \"Class\", \"Frequency\",\n",
+    "        title=f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\"\n",
+    "    )\n",
+    "})\n",
+    "\n",
+    "table = wandb.Table(columns=[\"Model\", \"Class\", \"Split\"])\n",
+    "category_dict = {key: 0 for key in config.categories}\n",
+    "for idx in tqdm(range(len(val_dataset[:100]))):\n",
+    "    point_cloud = wandb.Object3D(val_dataset[idx].pos.numpy())\n",
+    "    category = config.categories[int(val_dataset[idx].y.item())]\n",
+    "    category_dict[category] += 1\n",
+    "    table.add_data(\n",
+    "        point_cloud,\n",
+    "        category,\n",
+    "        \"Test\"\n",
+    "    )\n",
+    "wandb.log({config.modelnet_dataset_alias: table})\n",
+    "\n",
+    "data = [[key, category_dict[key]] for key in config.categories]\n",
+    "wandb.log({\n",
+    "    f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\" : wandb.plot.bar(\n",
+    "        wandb.Table(data=data, columns = [\"Class\", \"Frequency\"]),\n",
+    "        \"Class\", \"Frequency\",\n",
+    "        title=f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\"\n",
+    "    )\n",
+    "})\n",
+    "\n",
+    "wandb.finish()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/jupyters/my_shapenet_vis.ipynb b/dgcnn/jupyters/my_shapenet_vis.ipynb
new file mode 100644
index 0000000..bbd92c6
--- /dev/null
+++ b/dgcnn/jupyters/my_shapenet_vis.ipynb
@@ -0,0 +1,414 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import wandb\n",
+    "import random\n",
+    "random.seed(10)\n",
+    "\n",
+    "\n",
+    "wandb_project = \"my_shapenet_vis1\" \n",
+    "wandb_run_name = \"my_shapenet_vis_showcase\" \n",
+    "WANDB_NOTEBOOK_NAME = \"my_shapenet_vis.ipynb\"\n",
+    "\n",
+    "wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n",
+    "\n",
+    "# load shape net data\n",
+    "import sys\n",
+    "sys.path.append('/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn')\n",
+    "from shapenet_data_dgcnn import ShapenetDataDgcnn\n",
+    "shapenet_data = ShapenetDataDgcnn(\n",
+    "    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',\n",
+    "    npoints=512,\n",
+    "    return_cls_label=True,\n",
+    "    small_data=True,\n",
+    "    small_data_size=10,\n",
+    "    num_classes=1,\n",
+    "    data_augmentation=False,\n",
+    "    split='train',\n",
+    "    norm=True\n",
+    "    )\n",
+    "\n",
+    "# get first data point\n",
+    "data = shapenet_data[1]\n",
+    "data[0]\n",
+    "\n",
+    "# print(data[1])\n",
+    "\n",
+    "# find how many different values are in data[1]\n",
+    "import numpy as np\n",
+    "uv = np.unique(data[1])\n",
+    "\n",
+    "# generte random RGB colors for each class\n",
+    "import random\n",
+    "colors = []\n",
+    "for i in range(50):\n",
+    "    colors.append([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])\n",
+    "\n",
+    "colors_per_point = [colors[data[1][i]] for i in range(len(data[1]))]\n",
+    "\n",
+    "# create a point cloud\n",
+    "points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(data[0], colors_per_point)])\n",
+    "\n",
+    "point_cloud = wandb.Object3D(\n",
+    "    {\n",
+    "        \"type\": \"lidar/beta\",\n",
+    "        \"points\": points_rgb\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "# show point cloud in wandb\n",
+    "wandb.log({\"point_cloud\": point_cloud})\n",
+    "\n",
+    "\n",
+    "# create a point cloud points_rgb_0 by choosing only points with class 0\n",
+    "points_rgb_0 = points_rgb[data[1] == 0]\n",
+    "\n",
+    "print(points_rgb_0.shape)\n",
+    "\n",
+    "point_cloud_0 = wandb.Object3D(\n",
+    "    {\n",
+    "        \"type\": \"lidar/beta\",\n",
+    "        \"points\": points_rgb_0\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "# create a point cloud points_rgb_1 by choosing only points with class 1\n",
+    "points_rgb_1 = points_rgb[data[1] == 1]\n",
+    "\n",
+    "print(points_rgb_1.shape)\n",
+    "\n",
+    "point_cloud_1 = wandb.Object3D(\n",
+    "    {\n",
+    "        \n",
+    "        \"type\": \"lidar/beta\",\n",
+    "        \"points\": points_rgb_1\n",
+    "    }\n",
+    ")\n",
+    "\n",
+    "\n",
+    "# show point cloud in wandb\n",
+    "wandb.log({\"point_cloud_0\": point_cloud_0})\n",
+    "wandb.log({\"point_cloud_1\": point_cloud_1})\n",
+    "\n",
+    "# get the histogram for the data[1] (class labels)\n",
+    "labels_hist = {}\n",
+    "for label in data[1]:\n",
+    "    if label in labels_hist:\n",
+    "        labels_hist[label] += 1\n",
+    "    else:\n",
+    "        labels_hist[label] = 1\n",
+    "\n",
+    "# create a table for the histogram\n",
+    "table = wandb.Table(columns=[\"Class\", \"frequency\"])\n",
+    "\n",
+    "# write the class histogram to wandb\n",
+    "for class_name, count in labels_hist.items():\n",
+    "    table.add_data(class_name, count)\n",
+    "\n",
+    "wandb.log({\"class freq\": wandb.plot.bar(table, \"Class\", \"frequency\")})\n",
+    "\n",
+    "wandb.finish()\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "Finishing last run (ID:vrw80dza) before initializing another..."
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run <strong style=\"color:#cdcd00\">my_shapenet_vis_showcase</strong> at: <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/vrw80dza' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/vrw80dza</a><br/>Synced 5 W&B file(s), 1 media file(s), 31 artifact file(s) and 0 other file(s)"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find logs at: <code>./wandb/run-20230417_115825-vrw80dza/logs</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Successfully finished last run (ID:vrw80dza). Initializing new run:<br/>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Tracking run with wandb version 0.14.2"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Run data is saved locally in <code>/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/jupyters/wandb/run-20230417_120056-f4w7yi7s</code>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Syncing run <strong><a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s' target=\"_blank\">my_shapenet_vis_showcase</a></strong> to <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View project at <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run at <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s</a>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "({0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d575ea30>,\n",
+       "  1: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857cd0>,\n",
+       "  2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd91c8e0>,\n",
+       "  3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd91caf0>,\n",
+       "  4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57509a0>,\n",
+       "  5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578bdf0>,\n",
+       "  6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d36b7d30>,\n",
+       "  7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57fdca0>,\n",
+       "  8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57d6a60>,\n",
+       "  9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7e4250>},\n",
+       " {0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7e4490>,\n",
+       "  1: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7dbb50>,\n",
+       "  2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57f70a0>,\n",
+       "  3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5789340>,\n",
+       "  4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57894c0>,\n",
+       "  5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5777310>,\n",
+       "  6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d58012b0>,\n",
+       "  7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d575e910>,\n",
+       "  8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5790040>,\n",
+       "  9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd84a5e0>},\n",
+       " {0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd84a700>,\n",
+       "  1: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857eb0>,\n",
+       "  2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857e80>,\n",
+       "  3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750d60>,\n",
+       "  4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750f40>,\n",
+       "  5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750820>,\n",
+       "  6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750c70>,\n",
+       "  7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750cd0>,\n",
+       "  8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578bfa0>,\n",
+       "  9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578b700>})"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import wandb\n",
+    "import random\n",
+    "import numpy as np\n",
+    "random.seed(10)\n",
+    "\n",
+    "\n",
+    "wandb_project = \"my_shapenet_vis1\" \n",
+    "wandb_run_name = \"my_shapenet_vis_showcase\" \n",
+    "WANDB_NOTEBOOK_NAME = \"my_shapenet_vis.ipynb\"\n",
+    "\n",
+    "run = wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n",
+    "\n",
+    "# load shape net data\n",
+    "import sys\n",
+    "sys.path.append('/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn')\n",
+    "from shapenet_data_dgcnn import ShapenetDataDgcnn\n",
+    "shapenet_data = ShapenetDataDgcnn(\n",
+    "    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',\n",
+    "    npoints=512,\n",
+    "    return_cls_label=True,\n",
+    "    small_data=True,\n",
+    "    small_data_size=10,\n",
+    "    num_classes=1,\n",
+    "    data_augmentation=False,\n",
+    "    split='train',\n",
+    "    norm=True\n",
+    "    )\n",
+    "\n",
+    "# create a function which will create a point cloud for each class\n",
+    "def create_point_clouds(data, pred):\n",
+    "    data_point_clouds = {}\n",
+    "    pred_point_clouds = {}\n",
+    "    diff_point_clouds = {}\n",
+    "\n",
+    "    colors = []\n",
+    "    for i in range(50):\n",
+    "        colors.append([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])\n",
+    "\n",
+    "    for i in range(len(data)):\n",
+    "        colors_per_point = [colors[data[i][1][j]] for j in range(len(data[i][1]))]\n",
+    "        points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(data[i][0], colors_per_point)])\n",
+    "        data_point_clouds[i] = wandb.Object3D(\n",
+    "            {\n",
+    "                \"type\": \"lidar/beta\",\n",
+    "                \"points\": points_rgb\n",
+    "            }\n",
+    "        )\n",
+    "\n",
+    "    for i in range(len(pred)):\n",
+    "        colors_per_point = [colors[pred[i][1][j]] for j in range(len(pred[i][1]))]\n",
+    "        points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(pred[i][0], colors_per_point)])\n",
+    "        pred_point_clouds[i] = wandb.Object3D(\n",
+    "            {\n",
+    "                \"type\": \"lidar/beta\",\n",
+    "                \"points\": points_rgb\n",
+    "            }\n",
+    "        )\n",
+    "    \n",
+    "    for i in range(len(data)):\n",
+    "        diff_point_clouds[i] = data[i][0] - pred[i][0] + 1\n",
+    "        diff_point_clouds[i] = wandb.Object3D(\n",
+    "            {\n",
+    "                \"type\": \"lidar/beta\",\n",
+    "                \"points\": diff_point_clouds[i],\n",
+    "                \n",
+    "            }\n",
+    "        )\n",
+    "        \n",
+    "        # wandb.log({\"point_cloud_\" + str(i): point_clouds[i]})\n",
+    "        # table.add_data(point_clouds[i], point_clouds[i])\n",
+    "    table_data = [[data_point_clouds[i], pred_point_clouds[i], diff_point_clouds[i]] for i in range(len(data_point_clouds))]\n",
+    "    table = wandb.Table(data=table_data, columns=[\"gt\", \"pred\", \"diff\"])\n",
+    "\n",
+    "    # show the table in wandb\n",
+    "    run.log({\"point_clouds table \": table})\n",
+    "\n",
+    "    return data_point_clouds, pred_point_clouds, diff_point_clouds\n",
+    "\n",
+    "\n",
+    "# create a function which will create a histogram for each class\n",
+    "def create_histograms(data):\n",
+    "    labels_hist = {}\n",
+    "    for i in range(len(data)):\n",
+    "        labels_hist[i] = {}\n",
+    "        for label in data[i][1]:\n",
+    "            if label in labels_hist[i]:\n",
+    "                labels_hist[i][label] += 1\n",
+    "            else:\n",
+    "                labels_hist[i][label] = 1\n",
+    " \n",
+    "\n",
+    "    # create a table for the histogram\n",
+    "    table = wandb.Table(columns=[\"Class\", \"frequency\"])\n",
+    "\n",
+    "    # write the class histogram to wandb\n",
+    "    for class_name, count in labels_hist.items():\n",
+    "        table.add_data(class_name, count)\n",
+    "\n",
+    "    wandb.log({\"class freq\": wandb.plot.bar(table, \"Class\", \"frequency\")})\n",
+    "    \n",
+    "\n",
+    "# use the functions to create point clouds and histograms\n",
+    "create_point_clouds(shapenet_data, shapenet_data)\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/transform_net_run.ipynb b/dgcnn/jupyters/transform_net_run.ipynb
similarity index 65%
rename from dgcnn/transform_net_run.ipynb
rename to dgcnn/jupyters/transform_net_run.ipynb
index fa3166b..213d882 100644
--- a/dgcnn/transform_net_run.ipynb
+++ b/dgcnn/jupyters/transform_net_run.ipynb
@@ -81,6 +81,56 @@
     "# Apply the transformation matrix to the input point cloud\n",
     "input_tensor_transformed = torch.bmm(transform_matrix, input_tensor)\n"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Ground truth:  tensor([[1, 0],\n",
+      "        [0, 1],\n",
+      "        [1, 0],\n",
+      "        [0, 1],\n",
+      "        [1, 0],\n",
+      "        [0, 1]])\n",
+      "Predicted:  tensor([[1, 0],\n",
+      "        [0, 1],\n",
+      "        [0, 1],\n",
+      "        [0, 1],\n",
+      "        [1, 0],\n",
+      "        [0, 1]])\n",
+      "Jaccard index:  tensor(0.7500)\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torchmetrics\n",
+    "\n",
+    "# Define the Jaccard index (IoU) metric\n",
+    "jaccard_index = torchmetrics.JaccardIndex(num_classes=2, average=\"macro\", task=\"binary\")\n",
+    "\n",
+    "# Example ground truth and predicted labels\n",
+    "ground_truth = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.long)\n",
+    "predicted    = torch.tensor([0, 1, 1, 1, 0, 1], dtype=torch.long)\n",
+    "\n",
+    "# One-hot encode the ground truth and predicted labels\n",
+    "ground_truth_one_hot = torch.nn.functional.one_hot(ground_truth, num_classes=2)\n",
+    "predicted_one_hot = torch.nn.functional.one_hot(predicted, num_classes=2)\n",
+    "\n",
+    "print(\"Ground truth: \", ground_truth_one_hot)\n",
+    "print(\"Predicted: \", predicted_one_hot)\n",
+    "\n",
+    "\n",
+    "# Compute the Jaccard index\n",
+    "# jaccard_index = jaccard_index(predicted_one_hot, ground_truth_one_hot)\n",
+    "jaccard_index = jaccard_index(predicted, ground_truth)\n",
+    "print(\"Jaccard index: \", jaccard_index)"
+   ]
   }
  ],
  "metadata": {
diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py
index 3593570..fc275ac 100644
--- a/dgcnn/my_models/model_shape_net.py
+++ b/dgcnn/my_models/model_shape_net.py
@@ -7,7 +7,7 @@ from my_models.edge_conv_new import EdgeConvNew
 
 
 class DgcnShapeNet(nn.Module):
-    def __init__(self, seg_num_all):
+    def __init__(self, seg_num_all, num_classes):
         super(DgcnShapeNet, self).__init__()
         self.seg_num_all = seg_num_all
         self.transform_net = Transform_Net()
@@ -25,7 +25,7 @@ class DgcnShapeNet(nn.Module):
         self.conv6 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                    self.bn6,
                                    nn.LeakyReLU(negative_slope=0.2))
-        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
+        self.conv7 = nn.Sequential(nn.Conv1d(num_classes, 64, kernel_size=1, bias=False),
                                    self.bn7,
                                    nn.LeakyReLU(negative_slope=0.2))
         self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py
index abf9ba9..d518a63 100644
--- a/dgcnn/shapenet_data_dgcnn.py
+++ b/dgcnn/shapenet_data_dgcnn.py
@@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object):
                  small_data=False,
                  small_data_size=10,
                  return_cls_label=False,
-                 just_four_classes=False,
+                 num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4
                  norm=False,
                  augmnetation=False,
                  data_augmentation=False
@@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object):
         self.small_data = small_data
         self.small_data_size = small_data_size
         self.return_cls_label = return_cls_label
-        self.just_four_classes = just_four_classes
+        self.num_classes = num_classes
         self.norm = norm
         self.augmnetation = augmnetation
         self.data_augmentation = data_augmentation
@@ -90,12 +90,28 @@ class ShapenetDataDgcnn(object):
         # get one class of data
         # get the the number of the class airplane
 
-        if self.just_four_classes:
-            out_data = [x for x in out_data if x.split('/')[-2] in [
+        if self.num_classes is not None:
+            if self.num_classes == 1:
+                out_data = [x for x in out_data if x.split('/')[-2] in [
+                    self.cat['Airplane']
+                    ]]
+            elif self.num_classes == 2:
+                out_data = [x for x in out_data if x.split('/')[-2] in [
+                self.cat['Airplane'],
+                self.cat['Lamp']
+                ]]
+            elif self.num_classes == 3:
+                out_data = [x for x in out_data if x.split('/')[-2] in [
+                self.cat['Airplane'],
+                self.cat['Lamp'],
+                self.cat['Car']
+                ]]
+            elif self.num_classes == 4:
+                out_data = [x for x in out_data if x.split('/')[-2] in [
                 self.cat['Airplane'],
                 self.cat['Lamp'],
-                self.cat['Chair'],
-                self.cat['Table'],
+                self.cat['Car'],
+                self.cat['Chair']
                 ]]
         
         return out_data
@@ -207,8 +223,6 @@ class ShapenetDataDgcnn(object):
             point_set = self.rotate_pointcloud(point_set)
             point_set = self.translate_pointcloud(point_set)
 
-        
-
         choice = np.random.choice(len(point_set), self.npoints, replace=True)
 
         point_set = point_set[choice, :]
@@ -227,16 +241,18 @@ class ShapenetDataDgcnn(object):
             class_name = self.val_data_file[index].split('/')[-2]
 
         # apply the mapper
-        if self.just_four_classes:
-            class_name = self.class_mapper_4_classes(class_name)
-        else:
-            class_name = self.class_mapper(class_name)
+        # if self.num_classes:
+        #     class_name = self.class_mapper_4_classes(class_name)
+        # else:
+        #     class_name = self.class_mapper(class_name)
 
+        class_name = self.class_mapper(class_name)
+        
         # convert the class name to a number
         class_name = np.array(class_name, dtype=np.int64)
 
         # map to tensor
-        class_name = torch.from_numpy(class_name)
+        # class_name = torch.from_numpy(class_name)
 
         if self.return_cls_label:
             return point_set, labels, class_name
-- 
GitLab