Commit 1f92d285 authored by Maciej Wielgosz

updated shape in the model for dgcnn classification on shapeNet

parent e9981ec3
@@ -28,4 +28,6 @@ lightning_logs
 cifar-10-batches-py
 *.ckpt
 *.gz
-MNIST
\ No newline at end of file
+MNIST
+results
+ModelNet10
\ No newline at end of file
@@ -16,7 +16,6 @@ def main():
     print(input_tensor.shape)
     out = dgcnn(input_tensor)
     print(out.shape)
-    print(out)

 if __name__ == '__main__':
     main()
%% Cell type:code id: tags:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class EdgeConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, k=20):
        # x: (batch_size, num_points, in_channels)
        batch_size, num_points, feature_dim = x.shape
        x = x.view(batch_size, num_points, feature_dim)

        knn_indices = self.knn(x, k)
        knn_gathered = self.gather_neighbors(x, knn_indices)
        edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1, 1, k, 1), knn_gathered), dim=1)
        edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k)

        return self.conv(edge_features).transpose(2, 1)

    @staticmethod
    def knn(x, k):
        """Find the indices of the k nearest neighbors for each point in the input tensor."""
        batch_size, num_points, _ = x.shape
        x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)
        x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)
        distances = torch.norm(x_expanded - x_transposed, dim=-1)
        _, indices = distances.topk(k=k, dim=-1, largest=False)
        return indices

    def gather_neighbors(self, x, knn_indices):
        batch_size, num_points, _ = x.shape
        _, _, k = knn_indices.shape
        x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)
        neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, self.in_channels))
        return neighbors


# Generate a sample input point cloud
batch_size = 1
in_channels = 3
num_points = 1024
point_cloud = torch.randn(batch_size, in_channels, num_points)

# Create an EdgeConv layer
out_channels = 32  # you can choose the desired number of output channels
edge_conv_layer = EdgeConv(in_channels, out_channels)
edge_conv_layer.eval()

# Pass the sample input through the EdgeConv layer
k = 10
output = edge_conv_layer(point_cloud.view(batch_size, num_points, in_channels), k)
print("Output shape:", output.shape)
```
%% Output
Output shape: torch.Size([1, 1024, 32, 10])
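As a quick cross-check of that shape, here is a standalone walk-through of how the cell above assembles the edge-feature tensor (a sketch; the stand-in for `gather_neighbors` is random, since only shapes are being traced):

``` python
import torch

B, N, C, k = 1, 1024, 3, 10                        # batch, points, channels, neighbours, as above
x = torch.randn(B, N, C)
knn_gathered = torch.randn(B, N, k, C)             # stand-in for gather_neighbors(x, knn_indices)

central = x.unsqueeze(2).repeat(1, 1, k, 1)        # (B, N, k, C): each point repeated k times
edge = torch.cat((knn_gathered - central, knn_gathered), dim=1)   # (B, 2N, k, C)
edge = edge.view(B, 2 * C, N, k)                   # reshaped to (B, 2C, N, k) before the 1x1 Conv2d

out = torch.nn.Conv2d(2 * C, 32, kernel_size=1, bias=False)(edge)  # (B, 32, N, k)
print(out.transpose(2, 1).shape)                   # torch.Size([1, 1024, 32, 10]), matching the output above
```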
%% Cell type:code id: tags:
``` python
import torch


def knn(x, k):
    # Expects x with shape (batch_size, feature_dim, num_points).
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    knn_indices = pairwise_distance.topk(k=k, dim=-1)[1]
    return knn_indices


# Set sample input parameters
batch_size = 2
num_points = 1024
feature_dim = 3
k = 20

# Create a random input tensor
x = torch.randn(batch_size, num_points, feature_dim)
print("Sample input vector x:")
print(x.shape)

# Compute the k nearest neighbors (transpose to channel-first for this knn variant)
knn_indices = knn(x.transpose(2, 1), k)
print("k nearest neighbors indices:")
print(knn_indices.shape)
```
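A small cross-check of the matrix-multiplication trick above (a sketch; `torch.cdist` is used only as an independent reference): it relies on ||xi − xj||² = ||xi||² + ||xj||² − 2·xi·xj, so taking the topk of the negated squared distance picks the nearest points.

``` python
import torch

def knn(x, k):
    # x: (batch_size, feature_dim, num_points)
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)   # -||x_i - x_j||^2
    return pairwise_distance.topk(k=k, dim=-1)[1]

x = torch.randn(2, 1024, 3)
idx_fast = knn(x.transpose(2, 1), k=20)

# Reference: explicit pairwise distances
idx_ref = torch.cdist(x, x).topk(k=20, dim=-1, largest=False)[1]

# Compare the neighbour sets (ordering within the k can differ when distances are nearly tied)
match = (idx_fast.sort(dim=-1).values == idx_ref.sort(dim=-1).values).float().mean()
print(match)   # close to 1.0
```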
%% Cell type:code id: tags:
``` python
import torch


def knn(x, k):
    """Find the indices of the k nearest neighbors for each point in the input tensor."""
    batch_size, num_points, _ = x.shape
    x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)
    x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)
    distances = torch.norm(x_expanded - x_transposed, dim=-1)
    _, indices = distances.topk(k=k, dim=-1, largest=False)
    return indices


# Set sample input parameters
batch_size = 1
num_points = 1024
feature_dim = 3
k = 20

# Create a random input tensor
x = torch.randn(batch_size, num_points, feature_dim)
print("Sample input vector x:")
print(x.shape)

# Compute the k nearest neighbors
knn_indices = knn(x, k)
print("k nearest neighbors indices:")
print(knn_indices.shape)
```
%% Output
Sample input vector x:
torch.Size([1, 1024, 3])
k nearest neighbors indices:
torch.Size([1, 1024, 20])
%% Cell type:code id: tags:
``` python
x_expanded = x.unsqueeze(2).repeat(1, 1, num_points, 1)
print("x_expanded:", x_expanded.shape)
print(x_expanded[0, 0, 0, :])
print(x[0, 0, :])
```
%% Output
x_expanded: torch.Size([1, 1024, 1024, 3])
tensor([-0.7148, -0.1634, -1.3527])
tensor([-0.7148, -0.1634, -1.3527])
%% Cell type:code id: tags:
``` python
print('knn_indices:', knn_indices.shape)
# Gather the neighbors
neighbors = torch.gather(x_expanded, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3))
print("neighbors:", neighbors.shape)
```
%% Output
knn_indices: torch.Size([1, 1024, 20])
neighbors: torch.Size([1, 1024, 20, 3])
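One detail worth flagging (an observation about the cell above, not something changed in this commit): `x_expanded` repeats each point along dim 2, so `x_expanded[0, i, n, :]` equals `x[0, i, :]` for every `n`, and the gather therefore returns each centre point k times rather than its k neighbours. A sketch of the gather with the expansion taken along dim 1 instead, reusing `x`, `knn_indices`, `batch_size`, `num_points` and `k` from the cells above:

``` python
# Expand so that slot n holds point n, then pick the k neighbour coordinates per point.
x_src = x.unsqueeze(1).expand(-1, num_points, -1, -1)                      # (B, N, N, 3), [b, i, n] = x[b, n]
neighbors_actual = torch.gather(
    x_src, 2, knn_indices.view(batch_size, num_points, k, 1).repeat(1, 1, 1, 3)
)
print("neighbors_actual:", neighbors_actual.shape)                          # torch.Size([1, 1024, 20, 3])
print(torch.equal(neighbors_actual[0, 0, 0], x[0, knn_indices[0, 0, 0]]))   # True: a real neighbour's coordinates
```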
@@ -11,20 +11,21 @@ class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(EdgeConv, self).__init__()
+        self.in_channels = in_channels
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
             nn.BatchNorm2d(out_channels),
-            nn.ReLU() # TODO: replace with leaky relu
+            nn.ReLU()
         )

     def forward(self, x, k=20):
         # batch_size, num_points, in_channels
         batch_size, num_points, feature_dim = x.shape
         x = x.view(batch_size, num_points, feature_dim)

         knn_indices = self.knn(x, k)
         knn_gathered = self.gather_neighbors(x, knn_indices)
         edge_features = torch.cat((knn_gathered - x.unsqueeze(2).repeat(1,1,k,1), knn_gathered), dim=1)
-        edge_features = edge_features.view(batch_size, feature_dim, 2*num_points, k)
+        edge_features = edge_features.view(batch_size, 2*feature_dim, num_points, k)

         return self.conv(edge_features).transpose(2, 1)

     @staticmethod
@@ -93,6 +94,12 @@ class DGCNN(nn.Module):
         self.edge_conv2 = EdgeConv(64, 128)
         self.edge_conv3 = EdgeConv(128, 256)
         self.edge_conv4 = EdgeConv(256, 512)
+        self.bn5 = nn.BatchNorm1d(1024)
+        self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
+                                   self.bn5,
+                                   nn.LeakyReLU(negative_slope=0.2))
+        self.linear1 = nn.Linear(2048, 512, bias=False)
+
         self.fc = nn.Sequential(
             nn.Linear(512, 256),
             nn.BatchNorm1d(256),
@@ -107,17 +114,29 @@ class DGCNN(nn.Module):
     def forward(self, x):
         # Apply Transform_Net on input point cloud
         batch_size = x.size(0)
         trans_matrix = self.transform_net(x)
         x = torch.bmm(x, trans_matrix)

         x1 = self.edge_conv1(x)
         x1 = x1.max(dim=-1, keepdim=False)[0]
+        # print("x1 shape: ", x1.shape)

         x2 = self.edge_conv2(x1)
         x2 = x2.max(dim=-1, keepdim=False)[0]
+        # print("x2 shape: ", x2.shape)

         x3 = self.edge_conv3(x2)
         x3 = x3.max(dim=-1, keepdim=False)[0]
+        # print("x3 shape: ", x3.shape)

         x4 = self.edge_conv4(x3)
         x4 = x4.max(dim=-1, keepdim=False)[0]
+        # print("x4 shape: ", x4.shape)

-        x = torch.max(x4, dim=1, keepdim=True)[0]
-        x = self.fc(x.squeeze(1))
-        return x
\ No newline at end of file
+        # x5 = torch.cat((x1, x2, x3, x4), dim=1)  # (batch_size, 64+64+128+256, num_points)
+        # x6 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1)  # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # x7 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1)  # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
+        # x8 = torch.cat((x6, x7), 1)  # (batch_size, emb_dims*2)
+        # x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2)  # (batch_size, emb_dims*2) -> (batch_size, 512)

+        x9 = torch.max(x4, dim=1, keepdim=True)[0]
+        x10 = self.fc(x9.squeeze(1))
+        return x10
\ No newline at end of file
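For orientation, the path that is actually active in the updated `forward` collapses the per-point features into one global descriptor before the fully connected head (the commented `x5`–`x8` branch, along with the newly added `conv5`/`linear1` layers, appears unused in the code shown here). A minimal shape sketch, assuming the `(batch, points, channels)` layout produced by the EdgeConv blocks above:

``` python
import torch

B, N = 4, 1024
x4 = torch.randn(B, N, 512)                  # per-point features after edge_conv4 and the max over k

x9 = torch.max(x4, dim=1, keepdim=True)[0]   # global max over the point dimension -> (B, 1, 512)
x10_in = x9.squeeze(1)                       # (B, 512), the input expected by nn.Linear(512, 256) in self.fc
print(x9.shape, x10_in.shape)                # torch.Size([4, 1, 512]) torch.Size([4, 512])
```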
@@ -16,7 +16,7 @@ class ShapenetData(object):
                  small_data=False,
                  small_data_size=10,
                  just_one_class=False,
-                 norm=True,
+                 norm=False,
                  data_augmentation=False
                  ) -> None:
@@ -115,6 +115,9 @@ class ShapenetData(object):
         elif self.split == 'val':
             point_set = np.loadtxt(self.val_data_file[index]).astype(np.float32)

+        # get labels
+        labels = point_set[:, 6]
+
         # get just the points
         point_set = point_set[:, 0:3]
@@ -122,14 +125,18 @@
         if self.norm:
             point_set = self.normalize(point_set)

         # choice = np.random.choice(len(point_set), self.npoints, replace=True)
         choice = np.random.choice(len(point_set), self.npoints, replace=True)

         # chose the first npoints
         choice = np.arange(self.npoints)

         point_set = point_set[choice, :]
         point_set = point_set.astype(np.float32)
-        return point_set

+        # get the labels
+        labels = labels[choice]
+        labels = labels.astype(np.int64)
+
+        return point_set, labels

     def __len__(self):
         if self.split == 'train':
@@ -142,7 +149,10 @@
 if __name__ == "__main__":
-    shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
+    shapenet_data = ShapenetData(
+        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+        split='train'
+    )
     # print(shapenet_data.train_file_list)
     print(shapenet_data.get_seg_classes('Car'))
     print(shapenet_data.get_class_names())
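Since `__getitem__` now returns a `(point_set, labels)` pair, a typical consumer would look roughly like the sketch below (hypothetical usage; the root path and batch size are placeholders, and `ShapenetData` is the class from this file):

``` python
from torch.utils.data import DataLoader

# Placeholder path; point this at the actual ShapeNet directory.
dataset = ShapenetData(root='/path/to/shapenet', split='train')
loader = DataLoader(dataset, batch_size=8, shuffle=True)

points, labels = next(iter(loader))
print(points.shape, points.dtype)   # (8, npoints, 3), torch.float32 point coordinates
print(labels.shape, labels.dtype)   # (8, npoints), torch.int64 per-point labels taken from column 6
```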