# model.py
# Authored by Maciej Wielgosz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
# TODO: there are more conv layers there
class EdgeConv(nn.Module):
    """Edge convolution over a dynamic k-nearest-neighbour graph.

    For every point, the ``k`` nearest points are found using Euclidean
    distance in the current feature space; the point itself is normally among
    them because its distance is 0. Each (point, neighbour) edge is encoded as
    ``[neighbour - point, neighbour]``. The encoding goes through a shared 1x1
    ``Conv2d -> BatchNorm2d -> ReLU`` stack.

    Args:
        in_channels: Feature size ``C`` of each input point.
        out_channels: Feature size produced for every edge.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, k=20):
        """Compute per-edge features.

        Args:
            x: Tensor of shape ``(batch_size, num_points, in_channels)``.
            k: Number of neighbours per point; clamped to ``num_points``.

        Returns:
            Tensor of shape ``(batch_size, num_points, out_channels, k)``.
        """
        num_points = x.shape[1]
        k = min(k, num_points)  # topk would raise if k exceeded num_points

        knn_indices = self.knn(x, k)                       # (B, N, k)
        neighbors = self.gather_neighbors(x, knn_indices)  # (B, N, k, C)
        centers = x.unsqueeze(2).expand(-1, -1, k, -1)     # (B, N, k, C)

        # NOTE(review): the original DGCNN formulation pairs
        # ``neighbour - point`` with ``point`` rather than ``neighbour``;
        # confirm which encoding is intended.
        # Concatenate on the feature axis, then move it into the channel slot
        # Conv2d expects. A plain view() here would interleave points and features.
        edge_features = torch.cat((neighbors - centers, neighbors), dim=-1)
        edge_features = edge_features.permute(0, 3, 1, 2)  # (B, 2C, N, k)
        return self.conv(edge_features).transpose(2, 1)    # (B, N, out, k)

    @staticmethod
    def knn(x, k):
        """Find the indices of the k nearest neighbours of every point.

        Args:
            x: Tensor of shape ``(batch_size, num_points, dim)``.
            k: Neighbours per point; clamped to ``num_points``.

        Returns:
            Long tensor of shape ``(batch_size, num_points, k)``, ordered from
            nearest to farthest.
        """
        k = min(k, x.shape[1])
        # The indices are not differentiable, so skip autograd bookkeeping.
        with torch.no_grad():
            # cdist yields the (B, N, N) distance matrix without building the
            # (B, N, N, dim) difference tensor. The non-matmul mode keeps the
            # distances exact, so self-distance is exactly 0.
            distances = torch.cdist(
                x, x, compute_mode="donot_use_mm_for_euclid_dist"
            )
        return distances.topk(k=k, dim=-1, largest=False).indices

    def gather_neighbors(self, x, knn_indices):
        """Look up the feature vectors of each point's neighbours.

        Args:
            x: Tensor of shape ``(batch_size, num_points, C)``.
            knn_indices: Long tensor ``(batch_size, num_points, k)`` of indices
                into the point axis of ``x``.

        Returns:
            Tensor ``(batch_size, num_points, k, C)`` whose entry ``[b, i, m]``
            equals ``x[b, knn_indices[b, i, m]]``.
        """
        batch_index = torch.arange(x.shape[0], device=x.device).view(-1, 1, 1)
        # batch_index broadcasts against knn_indices, selecting rows of the
        # right cloud without materialising an (B, N, N, C) copy of x.
        return x[batch_index, knn_indices]
class Transform_Net(nn.Module):
    """Predict a ``k x k`` alignment matrix for a point cloud.

    Each point goes through three shared 1x1 convolutions (64, 128 and 1024
    channels), which lift it to 1024 features. These are max-pooled over the
    point axis and regressed by three fully connected layers to ``k * k``
    values. The identity matrix is added to those values, and the transposed
    result is returned.

    Args:
        k: Dimensionality of each input point and of the square output matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k
        # Keep this construction order: it fixes the parameter initialisation
        # under a given seed and the state_dict key order.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return the predicted transform for each cloud in the batch.

        Args:
            x: Tensor of shape ``(batch_size, num_points, k)``. It is moved to
                channels-first layout before the convolutions.

        Returns:
            Tensor of shape ``(batch_size, k, k)``.
        """
        feats = x.transpose(1, 2)  # (B, k, N): Conv1d wants channels on dim 1
        point_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in point_stages:
            feats = F.relu(norm(conv(feats)))

        global_feat = feats.max(dim=2).values  # (B, 1024), pooled over points
        hidden = F.relu(self.bn4(self.fc1(global_feat)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        offset = self.fc3(hidden).view(-1, self.k, self.k)

        eye = torch.eye(self.k, dtype=offset.dtype, device=offset.device)
        return (offset + eye).transpose(1, 2)
class DGCNN(nn.Module):
    """Point-cloud classifier built from stacked EdgeConv layers.

    Pipeline:

    1. ``Transform_Net`` predicts a 3x3 matrix, which is applied to the input
       points.
    2. Four ``EdgeConv`` layers (64, 128, 256 and 512 channels) run in turn;
       each output is max-reduced over its neighbours.
    3. The final per-point features are max-pooled over points.
    4. An MLP produces class logits.

    Args:
        num_classes: Number of output logits.
        k: Number of neighbours used by every EdgeConv layer (default 20).
    """

    def __init__(self, num_classes, k=20):
        super().__init__()
        self.k = k
        self.transform_net = Transform_Net()
        self.edge_conv1 = EdgeConv(3, 64)
        self.edge_conv2 = EdgeConv(64, 128)
        self.edge_conv3 = EdgeConv(128, 256)
        self.edge_conv4 = EdgeConv(256, 512)
        # NOTE(review): bn5, conv5 and linear1 are not used by forward().
        # Removing them would change the state_dict keys, so they are kept.
        self.bn5 = nn.BatchNorm1d(1024)
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, 1024, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.linear1 = nn.Linear(2048, 512, bias=False)
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: Tensor of shape ``(batch_size, num_points, 3)``.

        Returns:
            Logits of shape ``(batch_size, num_classes)``.
        """
        # Align the input cloud with the learned transform.
        trans_matrix = self.transform_net(x)  # (B, 3, 3)
        x = torch.bmm(x, trans_matrix)        # (B, N, 3)

        edge_convs = (self.edge_conv1, self.edge_conv2,
                      self.edge_conv3, self.edge_conv4)
        for edge_conv in edge_convs:
            # EdgeConv returns (B, N, C_out, k); keep the strongest edge response.
            x = edge_conv(x, k=self.k).max(dim=-1).values  # (B, N, C_out)

        pooled = x.max(dim=1).values  # (B, 512), global max over points
        return self.fc(pooled)