Skip to content
Snippets Groups Projects
Commit b4c9b22d authored by Maciej Wielgosz
Browse files

first steps in dgcnn

parent 5a2f3f6a
No related branches found
No related tags found
No related merge requests found
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from model import DGCNN
def main():
    """Smoke-test DGCNN: run one random point cloud through it and print the shapes."""
    class_count = 10
    network = DGCNN(class_count)
    # Inference mode: batch-norm uses running statistics and dropout is disabled.
    network.eval()

    # One cloud of 128 points with 3 coordinates each.
    sample = torch.randn(1, 128, 3)
    print(sample.shape)

    prediction = network(sample)
    print(prediction.shape)


if __name__ == "__main__":
    main()
%% Cell type:code id: tags:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class EdgeConv(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph built from the input points.

    For every point x_i and each of its k nearest neighbours x_j the layer forms
    the edge feature ``[x_j - x_i, x_j]`` (``2 * in_channels`` values) and feeds it
    through a shared 1x1 convolution, batch normalisation and ReLU.

    Args:
        in_channels: Number of features per input point.
        out_channels: Number of features produced per edge.
    """

    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.conv = nn.Sequential(
            # The edge feature concatenates two in_channels-wide vectors.
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, k=20):
        """Compute edge features for every point and its k nearest neighbours.

        Args:
            x: Tensor of shape (batch_size, num_points, in_channels).
            k: Neighbours per point; must not exceed num_points.

        Returns:
            Tensor of shape (batch_size, num_points, out_channels, k).
        """
        knn_indices = self.knn(x, k)
        neighbors = self.gather_neighbors(x, knn_indices)  # (B, N, k, C)
        centers = x.unsqueeze(2).expand_as(neighbors)      # (B, N, k, C)
        # Concatenate along the feature axis so each edge keeps its own
        # coordinates together: (B, N, k, 2C).
        edge_features = torch.cat((neighbors - centers, neighbors), dim=-1)
        # Conv2d expects channels first; permute moves the data consistently,
        # whereas view() would only reinterpret the buffer.
        edge_features = edge_features.permute(0, 3, 1, 2)  # (B, 2C, N, k)
        return self.conv(edge_features).transpose(2, 1)    # (B, N, out, k)

    @staticmethod
    def knn(x, k):
        """Find the indices of the k nearest neighbors for each point in the input tensor.

        Args:
            x: Tensor of shape (batch_size, num_points, feature_dim).
            k: Number of neighbours to return per point.

        Returns:
            Long tensor of shape (batch_size, num_points, k), nearest first.
        """
        # Only the indices are used, so no autograd graph is needed. cdist
        # avoids materialising the (B, N, N, C) difference tensor; the non-matmul
        # mode keeps a point's distance to itself at exactly zero.
        with torch.no_grad():
            distances = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")
            return distances.topk(k=k, dim=-1, largest=False).indices

    @staticmethod
    def gather_neighbors(x, knn_indices):
        """Look up the feature vectors of each point's neighbours.

        Args:
            x: Tensor of shape (batch_size, num_points, feature_dim).
            knn_indices: Long tensor of shape (batch_size, num_points, k) holding
                point indices into dim 1 of ``x``.

        Returns:
            Tensor of shape (batch_size, num_points, k, feature_dim) where entry
            ``[b, i, m]`` equals ``x[b, knn_indices[b, i, m]]``.
        """
        # Broadcasting a per-batch index against the neighbour ids selects
        # x[b, j] for every (b, i, m) without building an N x N table.
        batch_index = torch.arange(x.shape[0], device=x.device).view(-1, 1, 1)
        return x[batch_index, knn_indices]
# Sizes of the synthetic point cloud and of the layer under test.
batch_size, in_channels, num_points = 1, 3, 1024

# Random sample, drawn before the layer is built.
point_cloud = torch.randn(batch_size, in_channels, num_points)

# Build the EdgeConv layer in inference mode.
out_channels = 32  # You can choose the desired number of output channels
edge_conv_layer = EdgeConv(in_channels, out_channels)
edge_conv_layer.eval()

# The layer is called with (batch, points, channels); view() merely
# reinterprets the random buffer in that layout.
k = 10
points_last = point_cloud.view(batch_size, num_points, in_channels)
output = edge_conv_layer(points_last, k)
print("Output shape:", output.shape)
```
%% Output
Output shape: torch.Size([1, 2048, 32, 10])
%% Cell type:code id: tags:
``` python
def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    Distances are squared Euclidean, computed with the expansion
    ``||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2``.

    Args:
        x: Tensor of shape (batch_size, num_points, feature_dim), the layout
            this file passes to its knn helpers.
        k: Number of neighbours per point; must not exceed num_points.

    Returns:
        Long tensor of shape (batch_size, num_points, k), nearest first.
    """
    # -2 <x_i, x_j> for every pair of points: (batch_size, num_points, num_points).
    inner = -2 * torch.matmul(x, x.transpose(2, 1))
    # Squared norm of every point: (batch_size, num_points, 1).
    xx = torch.sum(x ** 2, dim=-1, keepdim=True)
    # Negated squared distance, so the largest entries are the closest points.
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    knn_indices = pairwise_distance.topk(k=k, dim=-1)[1]
    return knn_indices
import torch
# Sample dimensions: two clouds of 1024 three-dimensional points, 20 neighbours each.
batch_size, num_points, feature_dim = 2, 1024, 3
k = 20

# Random cloud laid out as (batch, points, features).
x = torch.randn(batch_size, num_points, feature_dim)
print("Sample input vector x:")
print(x.shape)

# Neighbour lookup for every point.
knn_indices = knn(x, k)
print("k nearest neighbors indices:")
print(knn_indices.shape)
```
%% Cell type:code id: tags:
``` python
import torch
def knn(x, k):
    """Return, for each point, the indices of its k closest points by Euclidean distance.

    Args:
        x: Tensor of shape (batch_size, num_points, feature_dim).
        k: Number of neighbours per point.

    Returns:
        Long tensor of shape (batch_size, num_points, k), nearest first.
    """
    # Unpacking doubles as a check that the input is three-dimensional.
    _, _, _ = x.shape
    # Broadcasting yields offsets[b, i, j] = x[b, i] - x[b, j].
    offsets = x.unsqueeze(2) - x.unsqueeze(1)
    pairwise = offsets.norm(dim=-1)
    nearest = pairwise.topk(k=k, dim=-1, largest=False)
    return nearest.indices
# One cloud of 1024 three-dimensional points; look up 20 neighbours per point.
batch_size = 1
num_points, feature_dim = 1024, 3
k = 20

# Random input in (batch, points, features) layout.
x = torch.randn(batch_size, num_points, feature_dim)
print("Sample input vector x:")
print(x.shape)

# Indices of the k nearest neighbours of each point.
knn_indices = knn(x, k)
print("k nearest neighbors indices:")
print(knn_indices.shape)
```
%% Output
Sample input vector x:
torch.Size([1, 1024, 3])
k nearest neighbors indices:
torch.Size([1, 1024, 20])
%% Cell type:code id: tags:
``` python
# Candidate table for the gather in the following cell: entry [b, i, j] must hold
# point j, so that indexing dim 2 with neighbour ids returns the neighbours of
# point i. Repeating along dim 2 instead would fill row i with copies of point i,
# and the gather would hand back the query point for every neighbour slot.
x_expanded = x.unsqueeze(1).repeat(1, num_points, 1, 1)
print("x_expanded:", x_expanded.shape)
# Entry [0, 0, 0] is point 0 of the first cloud, so both lines print the same values.
print(x_expanded[0, 0, 0, :])
print(x[0, 0, :])
```
%% Output
x_expanded: torch.Size([1, 1024, 1024, 3])
tensor([-0.7148, -0.1634, -1.3527])
tensor([-0.7148, -0.1634, -1.3527])
%% Cell type:code id: tags:
``` python
print('knn_indices:', knn_indices.shape)
# torch.gather needs one index per output element, so each neighbour id is
# tiled across the 3 coordinates before gathering along dim 2.
gather_index = knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3)
neighbors = torch.gather(x_expanded, 2, gather_index)
print("neighbors:", neighbors.shape)
```
%% Output
knn_indices: torch.Size([1, 1024, 20])
neighbors: torch.Size([1, 1024, 20, 3])
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
class EdgeConv(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph built from the input points.

    For every point x_i and each of its k nearest neighbours x_j the layer forms
    the edge feature ``[x_j - x_i, x_j]`` (``2 * in_channels`` values) and feeds it
    through a shared 1x1 convolution, batch normalisation and ReLU.

    Args:
        in_channels: Number of features per input point.
        out_channels: Number of features produced per edge.
    """

    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            # The edge feature concatenates two in_channels-wide vectors.
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, k=20):
        """Compute edge features for every point and its k nearest neighbours.

        Args:
            x: Tensor of shape (batch_size, num_points, in_channels).
            k: Neighbours per point; must not exceed num_points.

        Returns:
            Tensor of shape (batch_size, num_points, out_channels, k).
        """
        knn_indices = self.knn(x, k)
        neighbors = self.gather_neighbors(x, knn_indices)  # (B, N, k, C)
        centers = x.unsqueeze(2).expand_as(neighbors)      # (B, N, k, C)
        # Concatenate along the feature axis so each edge keeps its own
        # coordinates together: (B, N, k, 2C).
        edge_features = torch.cat((neighbors - centers, neighbors), dim=-1)
        # Conv2d expects channels first; permute moves the data consistently,
        # whereas view() would only reinterpret the buffer.
        edge_features = edge_features.permute(0, 3, 1, 2)  # (B, 2C, N, k)
        return self.conv(edge_features).transpose(2, 1)    # (B, N, out, k)

    @staticmethod
    def knn(x, k):
        """Find the indices of the k nearest neighbors for each point in the input tensor.

        Args:
            x: Tensor of shape (batch_size, num_points, feature_dim).
            k: Number of neighbours to return per point.

        Returns:
            Long tensor of shape (batch_size, num_points, k), nearest first.
        """
        # Only the indices are used, so no autograd graph is needed. cdist
        # avoids materialising the (B, N, N, C) difference tensor; the non-matmul
        # mode keeps a point's distance to itself at exactly zero.
        with torch.no_grad():
            distances = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")
            return distances.topk(k=k, dim=-1, largest=False).indices

    def gather_neighbors(self, x, knn_indices):
        """Look up the feature vectors of each point's neighbours.

        Args:
            x: Tensor of shape (batch_size, num_points, feature_dim).
            knn_indices: Long tensor of shape (batch_size, num_points, k) holding
                point indices into dim 1 of ``x``.

        Returns:
            Tensor of shape (batch_size, num_points, k, feature_dim) where entry
            ``[b, i, m]`` equals ``x[b, knn_indices[b, i, m]]``.
        """
        # Broadcasting a per-batch index against the neighbour ids selects
        # x[b, j] for every (b, i, m) without building an N x N table.
        batch_index = torch.arange(x.shape[0], device=x.device).view(-1, 1, 1)
        return x[batch_index, knn_indices]
class Transform_Net(nn.Module):
    """Predict a k x k transformation matrix for each input point cloud.

    Point-wise 1x1 convolutions are max-pooled over the points and passed
    through fully connected layers; the identity matrix is added to the result.

    Args:
        k: Number of features per point and side length of the predicted matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        # Shared per-point feature extractor.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        # Regressor from the pooled feature to the k*k matrix entries.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return one (k, k) matrix per cloud.

        Args:
            x: Tensor of shape (batch_size, num_points, k); it is transposed to
                channels-first before the convolutions.

        Returns:
            Tensor of shape (batch_size, k, k).
        """
        features = x.transpose(2, 1)
        point_layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in point_layers:
            features = F.relu(norm(conv(features)))

        # Max over the point axis, then flatten to (batch_size, 1024).
        pooled = torch.max(features, 2, keepdim=True)[0].view(-1, 1024)

        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        raw = self.fc3(hidden)

        # The network output is an offset from the identity matrix.
        eye = torch.eye(self.k, dtype=raw.dtype, device=raw.device)
        matrix = raw.view(-1, self.k, self.k) + eye
        return matrix.transpose(2, 1)
class DGCNN(nn.Module):
    """Point-cloud classifier: input transform, four EdgeConv stages, MLP head.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.transform_net = Transform_Net()

        # Feature widths grow 3 -> 64 -> 128 -> 256 -> 512 across the stages.
        self.edge_conv1 = EdgeConv(3, 64)
        self.edge_conv2 = EdgeConv(64, 128)
        self.edge_conv3 = EdgeConv(128, 256)
        self.edge_conv4 = EdgeConv(256, 512)

        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: Point clouds; presumably (batch_size, num_points, 3), matching the
                default Transform_Net and the first EdgeConv's 3 input channels.

        Returns:
            Output of the final linear layer, ``num_classes`` values per cloud.
        """
        # Multiply the input points by the predicted transformation matrix.
        features = torch.bmm(x, self.transform_net(x))

        stages = (self.edge_conv1, self.edge_conv2, self.edge_conv3, self.edge_conv4)
        for stage in stages:
            # Each stage output is reduced with a max over its last axis
            # before it feeds the next stage.
            features = stage(features).max(dim=-1, keepdim=False)[0]

        # Max over dim 1 leaves one 512-wide vector per cloud for the head.
        pooled = torch.max(features, dim=1, keepdim=True)[0]
        return self.fc(pooled.squeeze(1))
\ No newline at end of file
import json
import os
import numpy as np
class ShapenetData(object):
    """
    Data loader for the ShapeNet dataset. Only for data segmentation, not for classification.

    Files read below ``root``:
        raw/train_test_split/shuffled_{train,test,val}_file_list.json
        raw/synsetoffset2category.txt
        raw/<directory>/<name>.txt  (one point cloud per file)

    Indexing returns the first three columns of one cloud file as a float32
    array of shape (npoints, 3).
    """

    # Split name -> attribute that holds the file list of that split.
    _SPLIT_ATTRIBUTES = {
        'train': 'train_file_list',
        'test': 'test_file_list',
        'val': 'val_data_file',
    }

    def __init__(self,
                 root,
                 npoints=1200,
                 split='train',
                 small_data=False,
                 small_data_size=10,
                 just_one_class=False,
                 norm=True,
                 data_augmentation=False
                 ) -> None:
        """Read the category file and the three split lists.

        Args:
            root: Dataset directory containing the ``raw`` folder.
            npoints: Number of points returned per cloud.
            split: Which list indexing and ``len`` use: 'train', 'test' or 'val'.
            small_data: If true, keep only the first ``small_data_size`` files of
                every split.
            small_data_size: Number of files kept when ``small_data`` is set.
            just_one_class: If true, keep only files whose directory equals the
                category-file entry for 'Airplane'.
            norm: If true, centre and scale each cloud (see ``normalize``).
            data_augmentation: Stored on the instance; not used by this class.
        """
        self.npoints = npoints
        self.root = root
        self.split = split
        self.small_data = small_data
        self.small_data_size = small_data_size
        self.just_one_class = just_one_class
        self.norm = norm
        self.data_augmentation = data_augmentation

        # data operations
        split_dir = os.path.join(self.root, 'raw', 'train_test_split')
        self.train_data_suffled_list_file = os.path.join(
            split_dir, 'shuffled_train_file_list.json')
        self.test_data_suffled_list_file = os.path.join(
            split_dir, 'shuffled_test_file_list.json')
        # Holds the JSON path here; replaced by the resolved file list below.
        self.val_data_file = os.path.join(
            split_dir, 'shuffled_val_file_list.json')
        self.catfile = os.path.join(self.root,
                                    'raw',
                                    'synsetoffset2category.txt')

        # Part-label ids belonging to each object category.
        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
                            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
                            'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
                            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
                            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}

        # First token of each category-file line -> second token.
        self.cat = {}
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                if len(ls) < 2:
                    # Blank or incomplete line: nothing to map.
                    continue
                self.cat[ls[0]] = ls[1]

        self.train_file_list = self.get_list_shuffled_data(self.root, self.train_data_suffled_list_file)
        self.test_file_list = self.get_list_shuffled_data(self.root, self.test_data_suffled_list_file)
        self.val_data_file = self.get_list_shuffled_data(self.root, self.val_data_file)

        # Optionally keep only the first small_data_size files of every split.
        if self.small_data:
            self.train_file_list = self.train_file_list[:self.small_data_size]
            self.test_file_list = self.test_file_list[:self.small_data_size]
            self.val_data_file = self.val_data_file[:self.small_data_size]

    def get_list_shuffled_data(self, root, json_file):
        """Turn the entries of a split JSON file into point-cloud file paths.

        The last two '/'-separated components of every entry are used as
        ``<directory>/<name>`` and mapped to ``<root>/raw/<directory>/<name>.txt``.

        Args:
            root: Dataset directory containing the ``raw`` folder.
            json_file: Path of the JSON file holding a list of entries.

        Returns:
            List of file paths, restricted to the 'Airplane' directory when
            ``just_one_class`` is set.
        """
        with open(json_file, 'r') as f:
            entries = json.load(f)

        # Directory name to keep, or None to keep everything.
        wanted_dir = self.cat['Airplane'] if self.just_one_class else None

        paths = []
        for entry in entries:
            parts = entry.split('/')
            directory, name = parts[-2], parts[-1]
            # Filter on the JSON component rather than on the joined path, whose
            # separator depends on the operating system.
            if wanted_dir is not None and directory != wanted_dir:
                continue
            paths.append(os.path.join(root, 'raw', directory, name + '.txt'))
        return paths

    def get_seg_classes(self, cat):
        """Return the list of part-label ids for category name ``cat``."""
        return self.seg_classes[cat]

    def get_class_names(self):
        """Return the values (second column) of the category file.

        NOTE(review): the 'Airplane' lookup used for filtering treats these
        values as directory names, so this may return directory ids rather than
        human-readable names; confirm against the category file.
        """
        return list(self.cat.values())

    def get_all_names_of_classes(self):
        """Return the keys (first column) of the category file."""
        return list(self.cat.keys())

    def normalize(self, pc):
        """Centre ``pc`` on its centroid and scale it into the unit sphere.

        Args:
            pc: Array of shape (num_points, dims).

        Returns:
            The centred cloud divided by its largest point norm. A cloud whose
            points all coincide is returned centred but unscaled, avoiding a
            division by zero.
        """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        if m > 0:
            pc = pc / m
        return pc

    def _current_file_list(self):
        """Return the file list selected by ``self.split``; raise ValueError otherwise."""
        attribute = self._SPLIT_ATTRIBUTES.get(self.split)
        if attribute is None:
            raise ValueError(
                "unknown split %r; expected one of %s"
                % (self.split, sorted(self._SPLIT_ATTRIBUTES)))
        return getattr(self, attribute)

    # TODO: add the selection for a given class
    def __getitem__(self, index):
        """Load cloud ``index`` of the active split.

        Returns:
            float32 array of shape (npoints, 3) holding the first ``npoints``
            points of the file. Files with fewer points are cycled from the
            start until ``npoints`` rows are filled.

        Raises:
            ValueError: If ``self.split`` is unknown or the file has no points.
        """
        path = self._current_file_list()[index]
        # ndmin=2 keeps a single-row file two-dimensional for the column slice.
        point_set = np.loadtxt(path, ndmin=2).astype(np.float32)

        # get just the points
        point_set = point_set[:, 0:3]
        if len(point_set) == 0:
            raise ValueError("point cloud file %r contains no points" % (path,))

        # normalize the points
        if self.norm:
            point_set = self.normalize(point_set)

        # Take the first npoints rows; the modulo wraps around for short clouds
        # and is a no-op when the cloud has at least npoints rows.
        choice = np.arange(self.npoints) % len(point_set)
        point_set = point_set[choice, :]

        point_set = point_set.astype(np.float32)
        return point_set

    def __len__(self):
        """Return the number of files in the active split."""
        return len(self._current_file_list())
if __name__ == "__main__":
    # Manual smoke test against a local copy of the dataset.
    dataset_root = '/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet'
    shapenet_data = ShapenetData(root=dataset_root, split='train')

    # Category metadata: part ids of one category, then both columns of the
    # category file.
    car_parts = shapenet_data.get_seg_classes('Car')
    print(car_parts)
    class_names = shapenet_data.get_class_names()
    print(class_names)
    all_names = shapenet_data.get_all_names_of_classes()
    print(all_names)

    # Load the first cloud of the split and show its first ten points.
    point_cloud = shapenet_data[0]
    print(point_cloud[:10])
\ No newline at end of file
%% Cell type:code id: tags:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class Transform_Net(nn.Module):
    """Predict a k x k transformation matrix for each input point cloud.

    Point-wise 1x1 convolutions are max-pooled over the points and passed
    through fully connected layers; the identity matrix is added to the result.

    Args:
        k: Number of input channels and side length of the predicted matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        # Shared per-point feature extractor.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        # Regressor from the pooled feature to the k*k matrix entries.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return one (k, k) matrix per cloud.

        Args:
            x: Tensor of shape (batch_size, k, num_points).

        Returns:
            Tensor of shape (batch_size, k, k).
        """
        features = x
        point_layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in point_layers:
            features = F.relu(norm(conv(features)))

        # Max over the point axis, then flatten to (batch_size, 1024).
        pooled = torch.max(features, 2, keepdim=True)[0].view(-1, 1024)

        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        raw = self.fc3(hidden)

        # The network output is an offset from the identity matrix.
        eye = torch.eye(self.k, dtype=raw.dtype, device=raw.device)
        return raw.view(-1, self.k, self.k) + eye
# Demo dimensions: one cloud of 1024 points with k = 3 channels.
batch_size, num_points, k = 1, 1024, 3

# Random channels-first cloud of shape (batch_size, k, num_points).
input_tensor = torch.randn(batch_size, k, num_points)

# Build the module and switch it to inference mode.
transform_net = Transform_Net(k)
transform_net.eval()

# Predict the transformation matrix and report its shape.
transform_matrix = transform_net(input_tensor)
print("Transform matrix shape: ", transform_matrix.shape)

# (k x k) @ (k x num_points): transform every point of the cloud.
input_tensor_transformed = torch.bmm(transform_matrix, input_tensor)
```
%% Output
Transform matrix shape: torch.Size([1, 3, 3])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment