From bfd9359b03817f7d6b3c1d337fe83f962cdf51d6 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 6 Apr 2023 14:19:00 +0200
Subject: [PATCH] self attention added

---
 dgcnn/model_class.py | 75 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 2 deletions(-)

diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py
index 31878e7..7a9abb3 100644
--- a/dgcnn/model_class.py
+++ b/dgcnn/model_class.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from torch.nn import MultiheadAttention
 
 # TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
 # TODO: There are mode conv layers there
@@ -48,6 +49,77 @@ class EdgeConvNew(nn.Module):
         _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
         return idx
 
+# implement self attention
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels, num_heads, dropout):
+        super(SelfAttention, self).__init__()
+        self.in_channels = in_channels
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        x = x.permute(1, 0, 2)
+        out, attn = self.self_attention(x, x, x)
+        out = out.permute(1, 0, 2)
+        out = out.view(batch_size, -1, num_points)
+        return out
+
+class EdgeConvNewAtten(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConvNewAtten, self).__init__()
+        self.in_channels = in_channels
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+        self.self_attention = SelfAttention(2*in_channels*20, num_heads=8, dropout=0.1)
+
+    def forward(self, x, k=20):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=k)   # (batch_size, num_points, k)
+      
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+
+        idx = idx + idx_base
+
+        idx = idx.view(-1)
+    
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()   
+        feature = x.view(batch_size*num_points, -1)[idx, :]
+        feature = feature.view(batch_size, num_points, k, num_dims) 
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+        
+        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+        feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
+        # print("feature", feature.shape)
+        feature = feature.permute(0, 2, 1, 3).contiguous()
+        feature = feature.view(batch_size, num_points, -1) 
+
+        # print("feature", feature.shape)
+
+        feature = self.self_attention(feature)  # (batch_size, num_points, out_channels)
+        feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
+        # print("feature", feature.shape)
+    
+        return feature     
+
+    def knn(self, x, k):
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
+
 class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(EdgeConv, self).__init__()
@@ -133,7 +205,7 @@ class DgcnnClass(nn.Module):
         super(DgcnnClass, self).__init__()
         self.transform_net = Transform_Net()
         self.edge_conv1 = EdgeConvNew(3, 64)
-        self.edge_conv2 = EdgeConvNew(64, 128)
+        self.edge_conv2 = EdgeConvNewAtten(64, 128)
         self.bn5 = nn.BatchNorm1d(256)
         self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
                                    self.bn5,
@@ -160,7 +232,6 @@ class DgcnnClass(nn.Module):
         dim = x.size(2)
 
         x = x.view(batch_size, dim, num_points)
-
       
         x1 = self.edge_conv1(x)
         x1 = x1.max(dim=-1, keepdim=False)[0] 
-- 
GitLab