Skip to content
Snippets Groups Projects
Commit 0e069d09 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

dedicated model for classification

parent d1edd85f
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@ import yaml
from shapenet_data_dgcnn import ShapenetDataDgcnn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from model import DGCNN
from model_class import DgcnnClass
from torchmetrics import Accuracy, Precision, Recall
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.strategies import DDPStrategy
......@@ -41,7 +41,7 @@ with open('config.yaml', 'r') as f:
class DGCNNLightning(pl.LightningModule):
def __init__(self, num_classes):
super().__init__()
self.dgcnn = DGCNN(num_classes)
self.dgcnn = DgcnnClass(num_classes)
# train define metrics
self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
......@@ -124,11 +124,11 @@ class DGCNNLightning(pl.LightningModule):
self.test_recall.reset()
def configure_optimizers(self):
# optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
# return optimizer
optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
return optimizer
# optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
# scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
# return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
# get train data
......@@ -152,7 +152,7 @@ shapenet_data_val = ShapenetDataDgcnn(
small_data=config['data']['small_data'],
small_data_size=config['data']['small_data_size'],
just_one_class=config['data']['just_one_class'],
split='val',
split='train',
norm=config['data']['norm']
)
......
from tqdm import tqdm
from shapenet_data_dgcnn import ShapenetDataDgcnn
# Sanity check over the training split: report every sample whose first
# element does not contain exactly 128 entries along its leading axis.
dataset_options = dict(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=128,
    return_cls_label=True,
    small_data=False,
    small_data_size=1000,
    just_one_class=False,
    split='train',
    norm=True,
)
shapenet_data = ShapenetDataDgcnn(**dataset_options)

# Index the dataset item by item so a failing sample can be identified by
# its position.
for i in tqdm(range(len(shapenet_data))):
    data = shapenet_data[i]
    if data[0].shape[0] == 128:
        continue
    print(f"Data is None: {i}")
\ No newline at end of file
from tqdm import tqdm
from shapenet_data_dgcnn import ShapenetDataDgcnn
# Build the train / test / val datasets with identical settings (only the
# split name differs) and report how many samples each one holds.
shared_options = dict(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=128,
    return_cls_label=True,
    small_data=False,
    small_data_size=1000,
    just_one_class=True,
    norm=True,
)

shapenet_data_train = ShapenetDataDgcnn(split='train', **shared_options)
shapenet_data_test = ShapenetDataDgcnn(split='test', **shared_options)
shapenet_data_val = ShapenetDataDgcnn(split='val', **shared_options)

# Report the size of every split.
print(f"Train: {len(shapenet_data_train)}")
print(f"Test: {len(shapenet_data_test)}")
print(f"Val: {len(shapenet_data_val)}")
......@@ -100,6 +100,7 @@ class DGCNN(nn.Module):
self.bn5,
nn.LeakyReLU(negative_slope=0.2))
self.linear1 = nn.Linear(2048, 512, bias=False)
self.dropout = nn.Dropout(p=0.5)
self.fc = nn.Sequential(
nn.Linear(512, 256),
......@@ -143,5 +144,6 @@ class DGCNN(nn.Module):
x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2)
# x9 = x9.max(dim=1, keepdim=False)[0]
x10 = self.linear1(x8)
x10 = self.dropout(x10)
x11 = self.fc(x10)
return x11
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
# TODO: There are more conv layers there
class EdgeConv(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph of the input points.

    For every point the layer looks up its ``k`` nearest neighbours (by
    Euclidean distance between feature vectors), builds the per-edge feature
    ``(neighbour - centre, neighbour)`` and passes it through a shared
    1x1 convolution, batch normalisation and ReLU.

    Args:
        in_channels: Number of features per input point.
        out_channels: Number of features produced for every edge.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, k=20):
        """Compute edge features for every point and each of its neighbours.

        Args:
            x: Tensor of shape (batch_size, num_points, in_channels).
            k: Number of nearest neighbours per point. ``torch.topk`` raises
                if ``k`` exceeds ``num_points``.

        Returns:
            Tensor of shape (batch_size, num_points, out_channels, k).
        """
        knn_indices = self.knn(x, k)
        neighbors = self.gather_neighbors(x, knn_indices)  # (B, N, k, C)
        centers = x.unsqueeze(2)  # (B, N, 1, C), broadcast over the k axis

        # Concatenate along the channel axis and move channels to dim 1 with
        # permute. Concatenating on dim 1 and calling view() would merely
        # reinterpret memory and mix up points, neighbours and channels.
        edge_features = torch.cat((neighbors - centers, neighbors), dim=-1)  # (B, N, k, 2C)
        edge_features = edge_features.permute(0, 3, 1, 2)  # (B, 2C, N, k)
        return self.conv(edge_features).transpose(2, 1)

    @staticmethod
    def knn(x, k):
        """Find the indices of the k nearest neighbors for each point in the input tensor.

        Args:
            x: Tensor of shape (batch_size, num_points, channels).
            k: Number of neighbours to return per point.

        Returns:
            Integer tensor of shape (batch_size, num_points, k). Each point's
            own index is among its neighbours because its self-distance is 0.
        """
        # Pairwise differences via broadcasting: (B, N, 1, C) - (B, 1, N, C).
        distances = torch.norm(x.unsqueeze(2) - x.unsqueeze(1), dim=-1)
        _, indices = distances.topk(k=k, dim=-1, largest=False)
        return indices

    def gather_neighbors(self, x, knn_indices):
        """Look up the feature vectors of the neighbours selected by ``knn``.

        Args:
            x: Tensor of shape (batch_size, num_points, channels).
            knn_indices: Integer tensor of shape (batch_size, num_points, k)
                holding point indices into dim 1 of ``x``.

        Returns:
            Tensor of shape (batch_size, num_points, k, channels) where
            ``out[b, i, j] == x[b, knn_indices[b, i, j]]``.
        """
        batch_size = x.shape[0]
        # Pair every index with its batch entry so the lookup stays within
        # the same point cloud.
        batch_index = torch.arange(batch_size, device=x.device).view(-1, 1, 1)
        return x[batch_index, knn_indices]
class Transform_Net(nn.Module):
    """Predict one k x k transformation matrix per point cloud.

    Three point-wise 1x1 convolutions (k -> 64 -> 128 -> 256), each followed
    by batch norm and ReLU, are max-pooled over the points and projected to
    ``k * k`` values. The result is offset by the identity matrix.

    Args:
        k: Number of features per point, and side length of the predicted
            matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.fc3 = nn.Linear(256, k * k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)
        # bn4, bn5, dropout_0 and dropout_1 are not used by forward(); they
        # stay registered so the module's state_dict keys do not change.
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(64)

        self.dropout_0 = nn.Dropout(p=0.5)
        self.dropout_1 = nn.Dropout(p=0.5)
        self.dropout_2 = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return the transposed (identity + predicted offset) matrices.

        Args:
            x: Tensor of shape (batch_size, num_points, k). It is transposed
                to (batch_size, k, num_points) before the first Conv1d.

        Returns:
            Tensor of shape (batch_size, k, k).
        """
        features = x.transpose(2, 1)
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))

        # Max over the point axis leaves one 256-dim descriptor per cloud.
        pooled = features.max(dim=2)[0].view(-1, 256)
        flat = self.dropout_2(self.fc3(pooled))

        eye = torch.eye(self.k, dtype=flat.dtype, device=flat.device)
        matrices = flat.view(-1, self.k, self.k) + eye
        return matrices.transpose(2, 1)
class DgcnnClass(nn.Module):
    """Point-cloud classifier built from a Transform_Net and two EdgeConv layers.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.transform_net = Transform_Net()
        self.edge_conv1 = EdgeConv(3, 64)
        self.edge_conv2 = EdgeConv(64, 128)

        # bn5 is registered both as an attribute and inside conv5, so its
        # parameters appear under both names in the state_dict.
        self.bn5 = nn.BatchNorm1d(256)
        self.conv5 = nn.Sequential(
            nn.Conv1d(192, 256, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.linear1 = nn.Linear(512, 256, bias=False)
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a batch of point clouds to per-class scores.

        Args:
            x: Tensor of shape (batch_size, num_points, 3).

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # Align the input with the predicted 3x3 transform.
        aligned = torch.bmm(x, self.transform_net(x))  # (B, N, 3)

        # Each EdgeConv yields (B, N, C, k); reduce over the neighbour axis.
        feat1 = self.edge_conv1(aligned).max(dim=-1, keepdim=False)[0]  # (B, N, 64)
        feat2 = self.edge_conv2(feat1).max(dim=-1, keepdim=False)[0]  # (B, N, 128)

        stacked = torch.cat((feat1, feat2), dim=2)  # (B, N, 192)
        embedded = self.conv5(stacked.transpose(2, 1))  # (B, 256, N)

        # Global descriptor: max- and average-pool over the points, then join.
        pooled_max = F.adaptive_max_pool1d(embedded, 1).flatten(1)  # (B, 256)
        pooled_avg = F.adaptive_avg_pool1d(embedded, 1).flatten(1)  # (B, 256)
        descriptor = torch.cat((pooled_max, pooled_avg), 1)  # (B, 512)

        hidden = self.dropout(self.linear1(descriptor))  # (B, 256)
        return self.fc(hidden)
\ No newline at end of file
......@@ -90,7 +90,12 @@ class ShapenetDataDgcnn(object):
# get the the number of the class airplane
if self.just_one_class:
data = [x for x in data if x.split('/')[-2] == self.cat['Airplane']]
data = [x for x in data if x.split('/')[-2] in [
self.cat['Airplane'],
self.cat['Lamp'],
self.cat['Chair'],
self.cat['Table'],
]]
return data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment