# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
# TODO: the reference implementation has more conv layers than this one


class EdgeConv(nn.Module):
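    """EdgeConv layer (Dynamic Graph CNN, Wang et al.).

    For each point it finds the k nearest neighbours in the current feature
    space, builds edge features [x_j - x_i, x_j] for every neighbour j of
    point i, and applies a shared 1x1 convolution + BatchNorm + ReLU.
    The caller max-pools over the neighbour dimension of the returned tensor.
    """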
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, k=20):
        # x: (batch_size, num_points, in_channels)
        batch_size, num_points, feature_dim = x.shape
        knn_indices = self.knn(x, k)                           # (B, N, k)
        knn_gathered = self.gather_neighbors(x, knn_indices)   # (B, N, k, C)
        centers = x.unsqueeze(2).expand(-1, -1, k, -1)         # (B, N, k, C)
        # Edge feature for each (point, neighbour) pair: [x_j - x_i, x_j]
        edge_features = torch.cat((knn_gathered - centers, knn_gathered), dim=-1)  # (B, N, k, 2C)
        edge_features = edge_features.permute(0, 3, 1, 2)      # (B, 2C, N, k)
        return self.conv(edge_features).transpose(2, 1)        # (B, N, out_channels, k)

    @staticmethod
    def knn(x, k):
        """Return the indices of the k nearest neighbors of each point, based
        on pairwise Euclidean distances (the point itself, at distance zero,
        is included among its own neighbors)."""
        batch_size, num_points, _ = x.shape
        x_expanded = x.unsqueeze(2).expand(-1, -1, num_points, -1)    # (B, N, N, C)
        x_transposed = x.unsqueeze(1).expand(-1, num_points, -1, -1)  # (B, N, N, C)
        distances = torch.norm(x_expanded - x_transposed, dim=-1)     # (B, N, N)
        _, indices = distances.topk(k=k, dim=-1, largest=False)       # (B, N, k)
        return indices
    
    def gather_neighbors(self, x, knn_indices):
        """Gather the features of each point's k nearest neighbors."""
        batch_size, num_points, feature_dim = x.shape
        _, _, k = knn_indices.shape
        # Expand x so that dim 2 indexes candidate neighbor points: (B, N, N, C)
        x_expanded = x.unsqueeze(1).expand(-1, num_points, -1, -1)
        index = knn_indices.unsqueeze(-1).expand(-1, -1, -1, feature_dim)
        neighbors = torch.gather(x_expanded, 2, index)                 # (B, N, k, C)
        return neighbors
    
    
class Transform_Net(nn.Module):
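    """T-Net style spatial transformer: regresses a k x k transformation
    matrix (k=3 for raw xyz input) that is applied to the point cloud
    before the first EdgeConv layer."""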
    def __init__(self, k=3):
        super(Transform_Net, self).__init__()
        self.k = k

        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        # Input: (batch_size, num_points, k) -> transpose to (batch_size, k, num_points) for Conv1d
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        identity = torch.eye(self.k, dtype=x.dtype, device=x.device)
        transform = x.view(-1, self.k, self.k) + identity

        transform = transform.transpose(2, 1)

        return transform

class DGCNN(nn.Module):
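    """DGCNN classification network: a learned spatial transform followed by
    four stacked EdgeConv layers; per-point features from the last EdgeConv
    are max-pooled over all points and classified by a fully connected head."""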
    def __init__(self, num_classes):
        super(DGCNN, self).__init__()
        self.transform_net = Transform_Net()
        self.edge_conv1 = EdgeConv(3, 64)
        self.edge_conv2 = EdgeConv(64, 128)
        self.edge_conv3 = EdgeConv(128, 256)
        self.edge_conv4 = EdgeConv(256, 512)
        # Layers for the multi-scale aggregation path of the reference
        # implementation (see the TODO at the top of this file); not used by forward() yet.
        self.bn5 = nn.BatchNorm1d(1024)
        self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(2048, 512, bias=False)

        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        # Apply Transform_Net on input point cloud
        batch_size = x.size(0)

        trans_matrix = self.transform_net(x)   # (batch_size, 3, 3)
        x = torch.bmm(x, trans_matrix)         # aligned points: (batch_size, num_points, 3)
        # Each EdgeConv returns (batch_size, num_points, out_channels, k);
        # max-pool over the k neighbours to get per-point features.
        x1 = self.edge_conv1(x)
        x1 = x1.max(dim=-1, keepdim=False)[0]   # (batch_size, num_points, 64)
        x2 = self.edge_conv2(x1)
        x2 = x2.max(dim=-1, keepdim=False)[0]   # (batch_size, num_points, 128)
        x3 = self.edge_conv3(x2)
        x3 = x3.max(dim=-1, keepdim=False)[0]   # (batch_size, num_points, 256)
        x4 = self.edge_conv4(x3)
        x4 = x4.max(dim=-1, keepdim=False)[0]   # (batch_size, num_points, 512)
        # Multi-scale aggregation path from the reference implementation
        # (see the TODO at the top of this file); not wired up yet.
        # x5 = torch.cat((x1, x2, x3, x4), dim=1)  # (batch_size, 64+64+128+256, num_points)
        # x6 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        # x7 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        # x8 = torch.cat((x6, x7), 1)              # (batch_size, emb_dims*2)

        # x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)

        # Global max pooling over all points -> (batch_size, 512), then classify.
        x_global = torch.max(x4, dim=1)[0]
        return self.fc(x_global)
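

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original training pipeline.
# It assumes point clouds of shape (batch_size, num_points, 3); the batch
# size, point count, and class count below are arbitrary placeholder values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DGCNN(num_classes=10)
    model.eval()  # BatchNorm layers need eval mode (or batch_size > 1) to run
    points = torch.rand(2, 256, 3)  # (batch_size, num_points, xyz)
    with torch.no_grad():
        logits = model(points)
    print(logits.shape)  # expected: torch.Size([2, 10])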