From 81877145e1468ab4585687b15acf34f2a2c2b70d Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Thu, 20 Apr 2023 13:37:23 +0200 Subject: [PATCH] update --- cifar_example/cifar10_lightning.py | 113 +++++ cifar_example/cifar10_lightning_ver_2.py | 92 ++-- dgcnn/attention_usage.ipynb | 78 ++++ dgcnn/get_size_of_dataset.py | 6 +- .../{ => jupyters}/edge_conv_layer_run.ipynb | 25 ++ dgcnn/jupyters/model10_vis.ipynb | 304 +++++++++++++ dgcnn/jupyters/my_shapenet_vis.ipynb | 414 ++++++++++++++++++ dgcnn/{ => jupyters}/transform_net_run.ipynb | 50 +++ dgcnn/my_models/model_shape_net.py | 4 +- dgcnn/shapenet_data_dgcnn.py | 42 +- 10 files changed, 1064 insertions(+), 64 deletions(-) create mode 100644 cifar_example/cifar10_lightning.py create mode 100644 dgcnn/attention_usage.ipynb rename dgcnn/{ => jupyters}/edge_conv_layer_run.ipynb (93%) create mode 100644 dgcnn/jupyters/model10_vis.ipynb create mode 100644 dgcnn/jupyters/my_shapenet_vis.ipynb rename dgcnn/{ => jupyters}/transform_net_run.ipynb (65%) diff --git a/cifar_example/cifar10_lightning.py b/cifar_example/cifar10_lightning.py new file mode 100644 index 0000000..bec2a49 --- /dev/null +++ b/cifar_example/cifar10_lightning.py @@ -0,0 +1,113 @@ +import os +from torch import nn +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision import transforms + +from pytorch_lightning import LightningModule, Trainer +from torchvision.datasets import CIFAR10 +from pytorch_lightning.callbacks.progress import TQDMProgressBar +from pytorch_lightning.loggers import CSVLogger + +# import modules +from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer +from cifar_example.cifar_transformer_modules.embedding import Embedding + +# variables +PATH_DATASETS = os.environ.get("PATH_DATASETS", ".") +BATCH_SIZE = 256 if torch.cuda.is_available() else 64 + +class CIFAR10LightningTransformer(LightningModule): + def __init__(self): + super().__init__() + self.embedding_size=64 + self.criterion = torch.nn.CrossEntropyLoss() + self.embedding = Embedding( + patch_size=8, + in_channels=3, + out_channels=self.embedding_size, + return_patches=False, + extra_token=True + ) + self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3) + self.fc = nn.Linear(self.embedding_size, 10) + + def forward(self, x): + embedding = self.embedding(x) + context = self.self_attention(embedding) + # get the first token + context = context[:, 0, :] + + # context = context.mean(dim=1) + + # get the classification + context = self.fc(context) + + return context + + def accuracy(self, logits, y): + preds = torch.argmax(logits, dim=1) + return torch.sum(preds == y).item() / len(y) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.criterion(logits, y) + self.log('train_loss', loss) + self.log('train_acc', self.accuracy(logits, y)) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.criterion(logits, y) + self.log('val_loss', loss) + self.log('val_acc', self.accuracy(logits, y)) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = self.criterion(logits, y) + self.log('test_loss', loss) + self.log('test_acc', self.accuracy(logits, y)) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.02) + +cifar_model = CIFAR10LightningTransformer() + +# train_ds = 
CIFAR10(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor()) +# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE) +# val_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor()) +# val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE) +# test_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor()) +# test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE) + +# get the train data +train_ds = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor()) +# get test data +test_ds = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor()) + +# get the train loader +train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True) + +# get the test loader +test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False) + +# Initialize a trainer +trainer = Trainer( + accelerator="auto", + devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs + max_epochs=10, + callbacks=[TQDMProgressBar(refresh_rate=20)], + logger=CSVLogger(save_dir="logs/"), +) + +# Train the model ⚡ +# trainer.fit(cifar_model, train_loader) + +# Test the model +trainer.test(cifar_model, dataloaders=test_loader) \ No newline at end of file diff --git a/cifar_example/cifar10_lightning_ver_2.py b/cifar_example/cifar10_lightning_ver_2.py index c12ca99..c23b1d4 100644 --- a/cifar_example/cifar10_lightning_ver_2.py +++ b/cifar_example/cifar10_lightning_ver_2.py @@ -54,57 +54,57 @@ from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTrans from cifar_example.cifar_transformer_modules.embedding import Embedding # create resnet model for cifar10 classification -class CIFAR10Model(pl.LightningModule): - def __init__(self): - super().__init__() - self.model = self.create_model() +# class CIFAR10Model(pl.LightningModule): +# def __init__(self): +# super().__init__() +# self.model = self.create_model() - def forward(self, x): - return self.model(x) +# def forward(self, x): +# return self.model(x) - def create_model(self): - model = resnet18(pretrained=False, num_classes=10) - model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - model.maxpool = nn.Identity() - return model - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - # log accuracy - self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - # log accuracy - self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - # log accuracy - self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=0.001) - - def accuracy(self, y_hat, y): - preds = torch.argmax(y_hat, dim=1) - return (preds == y).float().mean() +# def 
create_model(self): +# model = resnet18(pretrained=False, num_classes=10) +# model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) +# model.maxpool = nn.Identity() +# return model + +# def training_step(self, batch, batch_idx): +# x, y = batch +# y_hat = self(x) +# loss = F.cross_entropy(y_hat, y) +# self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) +# # log accuracy +# self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) +# return loss + +# def validation_step(self, batch, batch_idx): +# x, y = batch +# y_hat = self(x) +# loss = F.cross_entropy(y_hat, y) +# self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) +# # log accuracy +# self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) +# return loss + +# def test_step(self, batch, batch_idx): +# x, y = batch +# y_hat = self(x) +# loss = F.cross_entropy(y_hat, y) +# self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) +# # log accuracy +# self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) +# return loss + +# def configure_optimizers(self): +# return torch.optim.Adam(self.parameters(), lr=0.001) + +# def accuracy(self, y_hat, y): +# preds = torch.argmax(y_hat, dim=1) +# return (preds == y).float().mean() # create resnet model for cifar10 classification -# class CIFAR10Model(pl.LightningModule): +class CIFAR10Model(pl.LightningModule): def __init__(self): super().__init__() self.embedding_size=64 diff --git a/dgcnn/attention_usage.ipynb b/dgcnn/attention_usage.ipynb new file mode 100644 index 0000000..4c7e038 --- /dev/null +++ b/dgcnn/attention_usage.ipynb @@ -0,0 +1,78 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "embed_dim must be divisible by num_heads", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb Cell 1\u001b[0m in \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=24'>25</a>\u001b[0m \u001b[39m# show how to use self attention\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=25'>26</a>\u001b[0m x \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m2\u001b[39m, \u001b[39m1024\u001b[39m, \u001b[39m3\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=26'>27</a>\u001b[0m self_attention \u001b[39m=\u001b[39m SelfAttention(\u001b[39m3\u001b[39;49m, \u001b[39m2\u001b[39;49m, \u001b[39m0.1\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=27'>28</a>\u001b[0m out \u001b[39m=\u001b[39m 
self_attention(x)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=28'>29</a>\u001b[0m \u001b[39mprint\u001b[39m(out\u001b[39m.\u001b[39mshape)\n", + "\u001b[1;32m/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb Cell 1\u001b[0m in \u001b[0;36m1\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_heads \u001b[39m=\u001b[39m num_heads\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdropout \u001b[39m=\u001b[39m dropout\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Boracle_docker/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/attention_usage.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mself_attention \u001b[39m=\u001b[39m MultiheadAttention(in_channels, num_heads\u001b[39m=\u001b[39;49mnum_heads, dropout\u001b[39m=\u001b[39;49mdropout)\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py:960\u001b[0m, in \u001b[0;36mMultiheadAttention.__init__\u001b[0;34m(self, embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatch_first \u001b[39m=\u001b[39m batch_first\n\u001b[1;32m 959\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhead_dim \u001b[39m=\u001b[39m embed_dim \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_heads\n\u001b[0;32m--> 960\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mhead_dim \u001b[39m*\u001b[39m num_heads \u001b[39m==\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39membed_dim, \u001b[39m\"\u001b[39m\u001b[39membed_dim must be divisible by num_heads\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 962\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_qkv_same_embed_dim:\n\u001b[1;32m 963\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mq_proj_weight \u001b[39m=\u001b[39m Parameter(torch\u001b[39m.\u001b[39mempty((embed_dim, embed_dim), \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mfactory_kwargs))\n", + "\u001b[0;31mAssertionError\u001b[0m: embed_dim must be divisible by num_heads" + ] + } + ], + "source": [ + "# import pytorch attention\n", + "from torch.nn import MultiheadAttention\n", + "from torch import nn\n", + "import torch\n", + "\n", + "# implement self attention\n", + "class SelfAttention(nn.Module):\n", + " def __init__(self, in_channels, num_heads, dropout):\n", + " super(SelfAttention, self).__init__()\n", + " self.in_channels = in_channels\n", + " self.num_heads = num_heads\n", + " self.dropout = dropout\n", + " self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)\n", + "\n", + " def forward(self, x):\n", + " batch_size = x.size(0)\n", + " num_points = x.size(2)\n", + " x = x.view(batch_size, -1, num_points)\n", + " x = x.permute(1, 0, 2)\n", + " out, attn = self.self_attention(x, x, x)\n", + " 
out = out.permute(1, 0, 2)\n",
+    "        out = out.view(batch_size, -1, num_points)\n",
+    "        return out\n",
+    "\n",
+    "# show how to use self attention\n",
+    "x = torch.randn(2, 1024, 3)\n",
+    "self_attention = SelfAttention(3, 3, 0.1)\n",
+    "out = self_attention(x)\n",
+    "print(out.shape)\n",
+    "\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dgcnn/get_size_of_dataset.py b/dgcnn/get_size_of_dataset.py
index af7cf04..1ee6847 100644
--- a/dgcnn/get_size_of_dataset.py
+++ b/dgcnn/get_size_of_dataset.py
@@ -7,7 +7,7 @@ shapenet_data_train = ShapenetDataDgcnn(
     return_cls_label=True,
     small_data=False,
     small_data_size=1000,
-    just_four_classes=True,
+    num_classes=4,
     split='train',
     norm=True
 )
@@ -18,7 +18,7 @@ shapenet_data_test = ShapenetDataDgcnn(
     return_cls_label=True,
     small_data=False,
     small_data_size=1000,
-    just_four_classes=True,
+    num_classes=4,
     split='test',
     norm=True
 )
@@ -29,7 +29,7 @@ shapenet_data_val = ShapenetDataDgcnn(
     return_cls_label=True,
     small_data=False,
     small_data_size=1000,
-    just_four_classes=True,
+    num_classes=4,
     split='val',
     norm=True
 )
diff --git a/dgcnn/edge_conv_layer_run.ipynb b/dgcnn/jupyters/edge_conv_layer_run.ipynb
similarity index 93%
rename from dgcnn/edge_conv_layer_run.ipynb
rename to dgcnn/jupyters/edge_conv_layer_run.ipynb
index b27286d..6792c92 100644
--- a/dgcnn/edge_conv_layer_run.ipynb
+++ b/dgcnn/jupyters/edge_conv_layer_run.ipynb
@@ -207,6 +207,31 @@
     "\n",
     "print(\"neighbors:\", neighbors.shape)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([32, 3])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# Assuming you have a tensor 'x' with shape (8, 32, 3)\n",
+    "x = torch.randn(8, 32, 3)\n",
+    "\n",
+    "# Remove the batch dimension\n",
+    "x_no_batch = x[1,:,:]\n",
+    "\n",
+    "print(x_no_batch.shape)"
+   ]
+  }
 ],
 "metadata": {
diff --git a/dgcnn/jupyters/model10_vis.ipynb b/dgcnn/jupyters/model10_vis.ipynb
new file mode 100644
index 0000000..5d5b91c
--- /dev/null
+++ b/dgcnn/jupyters/model10_vis.ipynb
@@ -0,0 +1,304 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1.13.1+cu117\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "os.environ['TORCH'] = torch.__version__\n",
+    "print(torch.__version__)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from glob import glob\n",
+    "from PIL import Image\n",
+    "from tqdm.auto import tqdm\n",
+    "\n",
+    "import wandb\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "import numpy as np\n",
+    "import networkx as nx\n",
+    "import matplotlib.pyplot as plt\n",
+    "from pyvis.network import Network\n",
+    "from mpl_toolkits.mplot3d import Axes3D\n",
+    "\n",
+    "import torch_geometric.transforms as T\n",
+    "from torch_geometric.datasets import ModelNet\n",
+    "from torch_geometric.loader import DataLoader\n",
+    "from torch_geometric.utils import 
to_networkx\n", + "from torch_geometric.nn import knn_graph, radius_graph" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "wandb version 0.14.2 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.10" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in <code>/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/jupyters/wandb/run-20230414_091510-h2iclcgp</code>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run <strong><a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">modelnet10/train/sampling-comparison</a></strong> to <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud</a>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp</a>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "wandb_project = \"pyg-point-cloud\" #@param {\"type\": \"string\"}\n", + "wandb_run_name = \"modelnet10/train/sampling-comparison\" #@param {\"type\": \"string\"}\n", + "\n", + "wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n", + "\n", + "# Set experiment configs to be synced with wandb\n", + "config = wandb.config\n", + "config.display_sample = 2048 #@param {type:\"slider\", min:256, max:4096, step:16}\n", + "config.modelnet_dataset_alias = \"ModelNet10\" #@param [\"ModelNet10\", \"ModelNet40\"] {type:\"raw\"}\n", + "\n", + "# Classes for ModelNet10 and ModelNet40\n", + "categories = sorted([\n", + " x.split(os.sep)[-2]\n", + " for x in glob(os.path.join(\n", + " config.modelnet_dataset_alias, \"raw\", '*', ''\n", + " ))\n", + "])\n", + "\n", + "\n", + "config.categories = categories" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "pre_transform = T.NormalizeScale()\n", + "transform = T.SamplePoints(config.display_sample)\n", + "train_dataset = ModelNet(\n", + " root=config.modelnet_dataset_alias,\n", + " name=config.modelnet_dataset_alias[-2:],\n", + " train=True,\n", + " transform=transform,\n", + " pre_transform=pre_transform\n", + ")\n", + "val_dataset = ModelNet(\n", + " root=config.modelnet_dataset_alias,\n", + " name=config.modelnet_dataset_alias[-2:],\n", + " 
train=False,\n", + " transform=transform,\n", + " pre_transform=pre_transform\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [00:01<00:00, 11.47it/s]\n", + "100%|██████████| 100/100 [00:02<00:00, 36.10it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run <strong style=\"color:#cdcd00\">modelnet10/train/sampling-comparison</strong> at: <a href='https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/pyg-point-cloud/runs/h2iclcgp</a><br/>Synced 5 W&B file(s), 3 media file(s), 103 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: <code>./wandb/run-20230414_091510-h2iclcgp/logs</code>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "table = wandb.Table(columns=[\"Model\", \"Class\", \"Split\"])\n", + "category_dict = {key: 0 for key in config.categories}\n", + "for idx in tqdm(range(len(train_dataset[:20]))):\n", + " point_cloud = wandb.Object3D(train_dataset[idx].pos.numpy())\n", + " category = config.categories[int(train_dataset[idx].y.item())]\n", + " category_dict[category] += 1\n", + " table.add_data(\n", + " point_cloud,\n", + " category,\n", + " \"Train\"\n", + " )\n", + "\n", + "data = [[key, category_dict[key]] for key in config.categories]\n", + "wandb.log({\n", + " f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\" : wandb.plot.bar(\n", + " wandb.Table(data=data, columns = [\"Class\", \"Frequency\"]),\n", + " \"Class\", \"Frequency\",\n", + " title=f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\"\n", + " )\n", + "})\n", + "\n", + "table = wandb.Table(columns=[\"Model\", \"Class\", \"Split\"])\n", + "category_dict = {key: 0 for key in config.categories}\n", + "for idx in tqdm(range(len(val_dataset[:100]))):\n", + " point_cloud = wandb.Object3D(val_dataset[idx].pos.numpy())\n", + " category = config.categories[int(val_dataset[idx].y.item())]\n", + " category_dict[category] += 1\n", + " table.add_data(\n", + " point_cloud,\n", + " category,\n", + " \"Test\"\n", + " )\n", + "wandb.log({config.modelnet_dataset_alias: table})\n", + "\n", + "data = [[key, category_dict[key]] for key in config.categories]\n", + "wandb.log({\n", + " f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\" : wandb.plot.bar(\n", + " wandb.Table(data=data, columns = [\"Class\", \"Frequency\"]),\n", + " \"Class\", \"Frequency\",\n", + " title=f\"{config.modelnet_dataset_alias} Class-Frequency Distribution\"\n", + " )\n", + "})\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dgcnn/jupyters/my_shapenet_vis.ipynb b/dgcnn/jupyters/my_shapenet_vis.ipynb new file mode 100644 index 0000000..bbd92c6 --- /dev/null +++ b/dgcnn/jupyters/my_shapenet_vis.ipynb @@ -0,0 +1,414 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "import random\n", + "random.seed(10)\n", + "\n", + "\n", + "wandb_project = \"my_shapenet_vis1\" \n", + "wandb_run_name = \"my_shapenet_vis_showcase\" \n", + "WANDB_NOTEBOOK_NAME = \"my_shapenet_vis.ipynb\"\n", + "\n", + "wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n", + "\n", + "# load shape net data\n", + "import sys\n", + "sys.path.append('/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn')\n", + "from shapenet_data_dgcnn import ShapenetDataDgcnn\n", + "shapenet_data = ShapenetDataDgcnn(\n", + " root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',\n", + " npoints=512,\n", + " return_cls_label=True,\n", + " small_data=True,\n", + " small_data_size=10,\n", + " num_classes=1,\n", + " data_augmentation=False,\n", + " split='train',\n", + " norm=True\n", + " )\n", + "\n", + "# get first data point\n", + "data = shapenet_data[1]\n", + "data[0]\n", + "\n", + "# print(data[1])\n", + "\n", + "# find how many different values are in data[1]\n", + "import numpy as np\n", + "uv = np.unique(data[1])\n", + "\n", + "# generte random RGB colors for each class\n", + "import random\n", + "colors = []\n", + "for i in range(50):\n", + " colors.append([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])\n", + "\n", + "colors_per_point = [colors[data[1][i]] for i in range(len(data[1]))]\n", + "\n", + "# create a point cloud\n", + "points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(data[0], colors_per_point)])\n", + "\n", + "point_cloud = wandb.Object3D(\n", + " {\n", + " \"type\": \"lidar/beta\",\n", + " \"points\": points_rgb\n", + " }\n", + ")\n", + "\n", + "# show point cloud in wandb\n", + "wandb.log({\"point_cloud\": point_cloud})\n", + "\n", + "\n", + "# create a point cloud points_rgb_0 by choosing only points with class 0\n", + "points_rgb_0 = points_rgb[data[1] == 0]\n", + "\n", + "print(points_rgb_0.shape)\n", + "\n", + "point_cloud_0 = wandb.Object3D(\n", + " {\n", + " \"type\": \"lidar/beta\",\n", + " \"points\": points_rgb_0\n", + " }\n", + ")\n", + "\n", + "# create a point cloud points_rgb_1 by choosing only points with class 1\n", + "points_rgb_1 = points_rgb[data[1] == 1]\n", + "\n", + "print(points_rgb_1.shape)\n", + "\n", + "point_cloud_1 = wandb.Object3D(\n", + " {\n", + " \n", + " \"type\": \"lidar/beta\",\n", + " \"points\": points_rgb_1\n", + " }\n", + ")\n", + "\n", + "\n", + "# show point cloud in wandb\n", + "wandb.log({\"point_cloud_0\": point_cloud_0})\n", + "wandb.log({\"point_cloud_1\": point_cloud_1})\n", + "\n", + "# get the histogram for the data[1] (class labels)\n", + "labels_hist = {}\n", + "for label in data[1]:\n", + " if label in labels_hist:\n", + " labels_hist[label] += 1\n", + " else:\n", + " labels_hist[label] = 1\n", + "\n", + "# create a table for the histogram\n", + "table = wandb.Table(columns=[\"Class\", \"frequency\"])\n", + "\n", + "# write the class histogram to wandb\n", + "for class_name, count in 
labels_hist.items():\n", + " table.add_data(class_name, count)\n", + "\n", + "wandb.log({\"class freq\": wandb.plot.bar(table, \"Class\", \"frequency\")})\n", + "\n", + "wandb.finish()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Finishing last run (ID:vrw80dza) before initializing another..." + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run <strong style=\"color:#cdcd00\">my_shapenet_vis_showcase</strong> at: <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/vrw80dza' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/vrw80dza</a><br/>Synced 5 W&B file(s), 1 media file(s), 31 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: <code>./wandb/run-20230417_115825-vrw80dza/logs</code>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Successfully finished last run (ID:vrw80dza). Initializing new run:<br/>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.14.2" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in <code>/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn/jupyters/wandb/run-20230417_120056-f4w7yi7s</code>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run <strong><a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s' target=\"_blank\">my_shapenet_vis_showcase</a></strong> to <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1</a>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at <a href='https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s' target=\"_blank\">https://wandb.ai/maciej-wielgosz-nibio/my_shapenet_vis1/runs/f4w7yi7s</a>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "({0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d575ea30>,\n", + " 1: 
<wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857cd0>,\n", + " 2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd91c8e0>,\n", + " 3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd91caf0>,\n", + " 4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57509a0>,\n", + " 5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578bdf0>,\n", + " 6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d36b7d30>,\n", + " 7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57fdca0>,\n", + " 8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57d6a60>,\n", + " 9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7e4250>},\n", + " {0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7e4490>,\n", + " 1: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd7dbb50>,\n", + " 2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57f70a0>,\n", + " 3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5789340>,\n", + " 4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d57894c0>,\n", + " 5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5777310>,\n", + " 6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d58012b0>,\n", + " 7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d575e910>,\n", + " 8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5790040>,\n", + " 9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd84a5e0>},\n", + " {0: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd84a700>,\n", + " 1: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857eb0>,\n", + " 2: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1dd857e80>,\n", + " 3: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750d60>,\n", + " 4: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750f40>,\n", + " 5: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750820>,\n", + " 6: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750c70>,\n", + " 7: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d5750cd0>,\n", + " 8: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578bfa0>,\n", + " 9: <wandb.sdk.data_types.object_3d.Object3D at 0x7fd1d578b700>})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import wandb\n", + "import random\n", + "import numpy as np\n", + "random.seed(10)\n", + "\n", + "\n", + "wandb_project = \"my_shapenet_vis1\" \n", + "wandb_run_name = \"my_shapenet_vis_showcase\" \n", + "WANDB_NOTEBOOK_NAME = \"my_shapenet_vis.ipynb\"\n", + "\n", + "run = wandb.init(project=wandb_project, entity=\"maciej-wielgosz-nibio\", name=wandb_run_name, job_type=\"eda\")\n", + "\n", + "# load shape net data\n", + "import sys\n", + "sys.path.append('/home/nibio/mutable-outside-world/code/oracle_gpu_runs/dgcnn')\n", + "from shapenet_data_dgcnn import ShapenetDataDgcnn\n", + "shapenet_data = ShapenetDataDgcnn(\n", + " root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',\n", + " npoints=512,\n", + " return_cls_label=True,\n", + " small_data=True,\n", + " small_data_size=10,\n", + " num_classes=1,\n", + " data_augmentation=False,\n", + " split='train',\n", + " norm=True\n", + " )\n", + "\n", + "# create a function which will create a point cloud for each class\n", + "def create_point_clouds(data, pred):\n", + " data_point_clouds = {}\n", + " pred_point_clouds = {}\n", + " diff_point_clouds = {}\n", + "\n", + " colors = []\n", + " for i in range(50):\n", + " colors.append([random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)])\n", + "\n", + " for i in range(len(data)):\n", + " colors_per_point = 
[colors[data[i][1][j]] for j in range(len(data[i][1]))]\n", + " points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(data[i][0], colors_per_point)])\n", + " data_point_clouds[i] = wandb.Object3D(\n", + " {\n", + " \"type\": \"lidar/beta\",\n", + " \"points\": points_rgb\n", + " }\n", + " )\n", + "\n", + " for i in range(len(pred)):\n", + " colors_per_point = [colors[pred[i][1][j]] for j in range(len(pred[i][1]))]\n", + " points_rgb = np.array([[p[0], p[1], p[2], c[0], c[1], c[2]] for p, c in zip(pred[i][0], colors_per_point)])\n", + " pred_point_clouds[i] = wandb.Object3D(\n", + " {\n", + " \"type\": \"lidar/beta\",\n", + " \"points\": points_rgb\n", + " }\n", + " )\n", + " \n", + " for i in range(len(data)):\n", + " diff_point_clouds[i] = data[i][0] - pred[i][0] + 1\n", + " diff_point_clouds[i] = wandb.Object3D(\n", + " {\n", + " \"type\": \"lidar/beta\",\n", + " \"points\": diff_point_clouds[i],\n", + " \n", + " }\n", + " )\n", + " \n", + " # wandb.log({\"point_cloud_\" + str(i): point_clouds[i]})\n", + " # table.add_data(point_clouds[i], point_clouds[i])\n", + " table_data = [[data_point_clouds[i], pred_point_clouds[i], diff_point_clouds[i]] for i in range(len(data_point_clouds))]\n", + " table = wandb.Table(data=table_data, columns=[\"gt\", \"pred\", \"diff\"])\n", + "\n", + " # show the table in wandb\n", + " run.log({\"point_clouds table \": table})\n", + "\n", + " return data_point_clouds, pred_point_clouds, diff_point_clouds\n", + "\n", + "\n", + "# create a function which will create a histogram for each class\n", + "def create_histograms(data):\n", + " labels_hist = {}\n", + " for i in range(len(data)):\n", + " labels_hist[i] = {}\n", + " for label in data[i][1]:\n", + " if label in labels_hist[i]:\n", + " labels_hist[i][label] += 1\n", + " else:\n", + " labels_hist[i][label] = 1\n", + " \n", + "\n", + " # create a table for the histogram\n", + " table = wandb.Table(columns=[\"Class\", \"frequency\"])\n", + "\n", + " # write the class histogram to wandb\n", + " for class_name, count in labels_hist.items():\n", + " table.add_data(class_name, count)\n", + "\n", + " wandb.log({\"class freq\": wandb.plot.bar(table, \"Class\", \"frequency\")})\n", + " \n", + "\n", + "# use the functions to create point clouds and histograms\n", + "create_point_clouds(shapenet_data, shapenet_data)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dgcnn/transform_net_run.ipynb b/dgcnn/jupyters/transform_net_run.ipynb similarity index 65% rename from dgcnn/transform_net_run.ipynb rename to dgcnn/jupyters/transform_net_run.ipynb index fa3166b..213d882 100644 --- a/dgcnn/transform_net_run.ipynb +++ b/dgcnn/jupyters/transform_net_run.ipynb @@ -81,6 +81,56 @@ "# Apply the transformation matrix to the input point cloud\n", "input_tensor_transformed = torch.bmm(transform_matrix, input_tensor)\n" ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ground truth: tensor([[1, 0],\n", + " [0, 1],\n", + " [1, 0],\n", + " [0, 1],\n", + " [1, 0],\n", + " [0, 1]])\n", + 
"Predicted: tensor([[1, 0],\n", + " [0, 1],\n", + " [0, 1],\n", + " [0, 1],\n", + " [1, 0],\n", + " [0, 1]])\n", + "Jaccard index: tensor(0.7500)\n" + ] + } + ], + "source": [ + "import torch\n", + "import torchmetrics\n", + "\n", + "# Define the Jaccard index (IoU) metric\n", + "jaccard_index = torchmetrics.JaccardIndex(num_classes=2, average=\"macro\", task=\"binary\")\n", + "\n", + "# Example ground truth and predicted labels\n", + "ground_truth = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.long)\n", + "predicted = torch.tensor([0, 1, 1, 1, 0, 1], dtype=torch.long)\n", + "\n", + "# One-hot encode the ground truth and predicted labels\n", + "ground_truth_one_hot = torch.nn.functional.one_hot(ground_truth, num_classes=2)\n", + "predicted_one_hot = torch.nn.functional.one_hot(predicted, num_classes=2)\n", + "\n", + "print(\"Ground truth: \", ground_truth_one_hot)\n", + "print(\"Predicted: \", predicted_one_hot)\n", + "\n", + "\n", + "# Compute the Jaccard index\n", + "# jaccard_index = jaccard_index(predicted_one_hot, ground_truth_one_hot)\n", + "jaccard_index = jaccard_index(predicted, ground_truth)\n", + "print(\"Jaccard index: \", jaccard_index)" + ] } ], "metadata": { diff --git a/dgcnn/my_models/model_shape_net.py b/dgcnn/my_models/model_shape_net.py index 3593570..fc275ac 100644 --- a/dgcnn/my_models/model_shape_net.py +++ b/dgcnn/my_models/model_shape_net.py @@ -7,7 +7,7 @@ from my_models.edge_conv_new import EdgeConvNew class DgcnShapeNet(nn.Module): - def __init__(self, seg_num_all): + def __init__(self, seg_num_all, num_classes): super(DgcnShapeNet, self).__init__() self.seg_num_all = seg_num_all self.transform_net = Transform_Net() @@ -25,7 +25,7 @@ class DgcnShapeNet(nn.Module): self.conv6 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False), self.bn6, nn.LeakyReLU(negative_slope=0.2)) - self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False), + self.conv7 = nn.Sequential(nn.Conv1d(num_classes, 64, kernel_size=1, bias=False), self.bn7, nn.LeakyReLU(negative_slope=0.2)) self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False), diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index abf9ba9..d518a63 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object): small_data=False, small_data_size=10, return_cls_label=False, - just_four_classes=False, + num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4 norm=False, augmnetation=False, data_augmentation=False @@ -29,7 +29,7 @@ class ShapenetDataDgcnn(object): self.small_data = small_data self.small_data_size = small_data_size self.return_cls_label = return_cls_label - self.just_four_classes = just_four_classes + self.num_classes = num_classes self.norm = norm self.augmnetation = augmnetation self.data_augmentation = data_augmentation @@ -90,12 +90,28 @@ class ShapenetDataDgcnn(object): # get one class of data # get the the number of the class airplane - if self.just_four_classes: - out_data = [x for x in out_data if x.split('/')[-2] in [ + if self.num_classes is not None: + if self.num_classes == 1: + out_data = [x for x in out_data if x.split('/')[-2] in [ + self.cat['Airplane'] + ]] + elif self.num_classes == 2: + out_data = [x for x in out_data if x.split('/')[-2] in [ + self.cat['Airplane'], + self.cat['Lamp'] + ]] + elif self.num_classes == 3: + out_data = [x for x in out_data if x.split('/')[-2] in [ + self.cat['Airplane'], + self.cat['Lamp'], + self.cat['Car'] + ]] + 
elif self.num_classes == 4: + out_data = [x for x in out_data if x.split('/')[-2] in [ self.cat['Airplane'], self.cat['Lamp'], - self.cat['Chair'], - self.cat['Table'], + self.cat['Car'], + self.cat['Chair'] ]] return out_data @@ -207,8 +223,6 @@ class ShapenetDataDgcnn(object): point_set = self.rotate_pointcloud(point_set) point_set = self.translate_pointcloud(point_set) - - choice = np.random.choice(len(point_set), self.npoints, replace=True) point_set = point_set[choice, :] @@ -227,16 +241,18 @@ class ShapenetDataDgcnn(object): class_name = self.val_data_file[index].split('/')[-2] # apply the mapper - if self.just_four_classes: - class_name = self.class_mapper_4_classes(class_name) - else: - class_name = self.class_mapper(class_name) + # if self.num_classes: + # class_name = self.class_mapper_4_classes(class_name) + # else: + # class_name = self.class_mapper(class_name) + class_name = self.class_mapper(class_name) + # convert the class name to a number class_name = np.array(class_name, dtype=np.int64) # map to tensor - class_name = torch.from_numpy(class_name) + # class_name = torch.from_numpy(class_name) if self.return_cls_label: return point_set, labels, class_name -- GitLab
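
On cifar_example/cifar10_lightning.py: the model classifies from the first token of the transformer output (`context[:, 0, :]`), i.e. the extra class token that `Embedding(..., extra_token=True)` prepends to the patch sequence. The project's `Embedding` and `MyTransformerLayer` modules are not included in this patch, so the following is only a minimal stand-in sketch of the pattern (patch size 8 on 32x32 CIFAR images gives 16 patches plus one class token), not the actual implementation:

# Illustrative stand-in for the project's Embedding module, which is not in this patch.
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Split a 32x32 image into non-overlapping 8x8 patches, project to d_model,
    and prepend a learnable class token (mirrors extra_token=True)."""
    def __init__(self, patch_size=8, in_channels=3, d_model=64):
        super().__init__()
        # A strided conv is the standard way to embed non-overlapping patches.
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, x):                      # x: (B, 3, 32, 32)
        x = self.proj(x)                       # (B, 64, 4, 4)
        x = x.flatten(2).transpose(1, 2)       # (B, 16, 64) - one row per patch
        cls = self.cls_token.expand(x.size(0), -1, -1)
        return torch.cat([cls, x], dim=1)      # (B, 17, 64)

emb = PatchEmbedding()
tokens = emb(torch.randn(4, 3, 32, 32))
# The Lightning model reads out the class token only: tokens[:, 0, :]
logits = nn.Linear(64, 10)(tokens[:, 0, :])
print(tokens.shape, logits.shape)              # torch.Size([4, 17, 64]) torch.Size([4, 10])

The commented-out `context.mean(dim=1)` in the forward pass is the usual alternative readout: mean pooling over all tokens instead of a dedicated class token.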
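
On dgcnn/attention_usage.ipynb: the saved traceback comes from an earlier run of the cell with `SelfAttention(3, 2, 0.1)`. `nn.MultiheadAttention` splits `embed_dim` evenly across heads (`head_dim = embed_dim // num_heads`), so `embed_dim` must be divisible by `num_heads`; with `embed_dim=3` only 1 or 3 heads are valid, which is why the committed source uses `SelfAttention(3, 3, 0.1)`. A minimal check:

import torch
from torch.nn import MultiheadAttention

embed_dim, num_heads = 3, 2
# This pairing is what the saved traceback shows failing: 3 is not divisible by 2.
assert embed_dim % num_heads != 0

# A working configuration. Note the default input layout is
# (seq_len, batch, embed_dim) unless batch_first=True, hence the permutes
# in the notebook's forward().
mha = MultiheadAttention(embed_dim=3, num_heads=3, dropout=0.1)
x = torch.randn(1024, 2, 3)            # 1024 points, batch of 2, 3 coordinates
out, attn_weights = mha(x, x, x)       # self-attention: query = key = value
print(out.shape, attn_weights.shape)   # (1024, 2, 3), (2, 1024, 1024)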
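
On the cell added to dgcnn/jupyters/edge_conv_layer_run.ipynb: the comment says "Remove the batch dimension", but `x[1,:,:]` actually selects the second sample of the batch; the batch axis disappears only because integer indexing drops it. For a batch of size one, `squeeze(0)` is the shape-only operation:

import torch

x = torch.randn(8, 32, 3)
sample = x[1]                               # (32, 3) - the second element of the batch
single = torch.randn(1, 32, 3).squeeze(0)   # (32, 3) - batch dim actually removed
print(sample.shape, single.shape)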
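
On dgcnn/jupyters/my_shapenet_vis.ipynb: the per-point color lookup is built with nested Python list comprehensions; the same `lidar/beta` payload can be produced with a single NumPy indexing step. A sketch with stand-in data (in the notebook, `points` and `labels` come from `ShapenetDataDgcnn.__getitem__` as `data[0]` and `data[1]`):

import numpy as np

rng = np.random.default_rng(10)
palette = rng.integers(0, 256, size=(50, 3))   # one RGB color per part class

points = rng.standard_normal((512, 3))         # stand-in for data[0]
labels = rng.integers(0, 4, size=512)          # stand-in for data[1]

# Fancy indexing maps each label to its RGB row in one step.
points_rgb = np.hstack([points, palette[labels]])   # (N, 6): x, y, z, r, g, b
print(points_rgb.shape)                             # (512, 6)
# wandb.Object3D({"type": "lidar/beta", "points": points_rgb})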
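
On the cell added to dgcnn/jupyters/transform_net_run.ipynb: with `task="binary"`, `torchmetrics.JaccardIndex` reduces to the IoU of the positive class, which is where the printed 0.7500 comes from. The one-hot tensors are only printed; the metric consumes the integer labels directly, and in recent torchmetrics versions the `num_classes`/`average` arguments are effectively unused for the binary task. A hand computation that reproduces the value:

import torch

ground_truth = torch.tensor([0, 1, 0, 1, 0, 1])
predicted    = torch.tensor([0, 1, 1, 1, 0, 1])

# Binary Jaccard = |pred positives AND true positives| / |pred positives OR true positives|
intersection = ((predicted == 1) & (ground_truth == 1)).sum()   # 3
union        = ((predicted == 1) | (ground_truth == 1)).sum()   # 4
print(intersection / union)                                     # tensor(0.7500)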
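
On dgcnn/my_models/model_shape_net.py: `conv7` consumes a one-hot object-category vector, so widening its input from the hard-coded 16 to `num_classes` keeps the layer consistent with datasets filtered to fewer categories. The training loop is not part of this patch, so the following is only an assumed-usage sketch of how such a label vector is typically built and tiled over points in DGCNN part segmentation:

import torch
import torch.nn.functional as F

batch_size, num_classes, num_points = 8, 4, 1024

cls_label = torch.randint(0, num_classes, (batch_size,))   # object class per cloud
one_hot = F.one_hot(cls_label, num_classes).float()        # (B, num_classes)
one_hot = one_hot.unsqueeze(-1)                            # (B, num_classes, 1)

# Matches the patched layer: Conv1d(num_classes, 64, kernel_size=1, bias=False)
conv7 = torch.nn.Conv1d(num_classes, 64, kernel_size=1, bias=False)
label_feat = conv7(one_hot)                                # (B, 64, 1)
label_feat = label_feat.repeat(1, 1, num_points)           # broadcast over points
print(label_feat.shape)                                    # torch.Size([8, 64, 1024])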
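
On dgcnn/shapenet_data_dgcnn.py: the new `if`/`elif` ladder in `get_data_files()` always selects a growing prefix of the same four categories (Airplane, Lamp, Car, Chair), so it can be collapsed into one slice. A sketch using the patch's own convention that `self.cat` maps category names to ShapeNet directory ids (the ids and paths below are illustrative, not real dataset contents):

ORDERED_CATEGORIES = ['Airplane', 'Lamp', 'Car', 'Chair']

def filter_by_num_classes(out_data, cat, num_classes):
    """cat maps category name -> directory id, as self.cat does in the patch."""
    if num_classes is None:          # None keeps all categories
        return out_data
    keep = {cat[name] for name in ORDERED_CATEGORIES[:num_classes]}
    return [path for path in out_data if path.split('/')[-2] in keep]

# Usage sketch with illustrative ids and file paths.
demo_cat = {'Airplane': '02691156', 'Lamp': '03636649', 'Car': '02958343', 'Chair': '03001627'}
files = ['data/02691156/a.txt', 'data/03001627/b.txt', 'data/03636649/c.txt']
print(filter_by_num_classes(files, demo_cat, 2))   # keeps the Airplane and Lamp files

Note that `num_classes` is documented as `None` or an int up to 4; because `True == 1` in Python, passing a bool would silently select a single class rather than failing.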