From bfd9359b03817f7d6b3c1d337fe83f962cdf51d6 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Thu, 6 Apr 2023 14:19:00 +0200
Subject: [PATCH] self attention added

---
 dgcnn/model_class.py | 75 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 2 deletions(-)

diff --git a/dgcnn/model_class.py b/dgcnn/model_class.py
index 31878e7..7a9abb3 100644
--- a/dgcnn/model_class.py
+++ b/dgcnn/model_class.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.init as init
+from torch.nn import MultiheadAttention
 
 # TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
 # TODO: There are mode conv layers there
@@ -48,6 +49,77 @@ class EdgeConvNew(nn.Module):
         _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
         return idx
 
+# implement self attention
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels, num_heads, dropout):
+        super(SelfAttention, self).__init__()
+        self.in_channels = in_channels
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        x = x.permute(1, 0, 2)
+        out, attn = self.self_attention(x, x, x)
+        out = out.permute(1, 0, 2)
+        out = out.view(batch_size, -1, num_points)
+        return out
+
+class EdgeConvNewAtten(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(EdgeConvNewAtten, self).__init__()
+        self.in_channels = in_channels
+        self.conv = nn.Sequential(
+            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+        self.self_attention = SelfAttention(2*in_channels*20, num_heads=8, dropout=0.1)
+
+    def forward(self, x, k=20):
+        batch_size = x.size(0)
+        num_points = x.size(2)
+        x = x.view(batch_size, -1, num_points)
+        idx = self.knn(x, k=k)  # (batch_size, num_points, k)
+
+        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
+
+        idx = idx + idx_base
+
+        idx = idx.view(-1)
+
+        _, num_dims, _ = x.size()
+
+        x = x.transpose(2, 1).contiguous()
+        feature = x.view(batch_size*num_points, -1)[idx, :]
+        feature = feature.view(batch_size, num_points, k, num_dims)
+
+        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
+
+        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
+
+        feature = self.conv(feature)  # (batch_size, num_dims, num_points, k)
+        # print("feature", feature.shape)
+        feature = feature.permute(0, 2, 1, 3).contiguous()
+        feature = feature.view(batch_size, num_points, -1)
+
+        # print("feature", feature.shape)
+
+        feature = self.self_attention(feature)  # (batch_size, num_points, out_channels)
+        feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
+        # print("feature", feature.shape)
+
+        return feature
+
+    def knn(self, x, k):
+        x = x.transpose(2, 1)
+        pairwise_distance = torch.cdist(x, x, p=2)
+        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
+        return idx
+
 class EdgeConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(EdgeConv, self).__init__()
@@ -133,7 +205,7 @@ class DgcnnClass(nn.Module):
         super(DgcnnClass, self).__init__()
         self.transform_net = Transform_Net()
         self.edge_conv1 = EdgeConvNew(3, 64)
-        self.edge_conv2 = EdgeConvNew(64, 128)
+        self.edge_conv2 = EdgeConvNewAtten(64, 128)
         self.bn5 = nn.BatchNorm1d(256)
         self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
                                    self.bn5,
@@ -160,7 +232,6 @@ class DgcnnClass(nn.Module):
         dim = x.size(2)
 
         x = x.view(batch_size, dim, num_points)
-
        x1 = self.edge_conv1(x)
        x1 = x1.max(dim=-1, keepdim=False)[0]
 
-- 
GitLab
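
Reviewer note (not part of the patch): a minimal smoke-test sketch for the new modules, assuming the patched file is importable as dgcnn.model_class and is on PYTHONPATH. The channel counts, k=20 neighbourhood and the resulting embed_dim of 2*64*20 = 2560 are taken from the diff; the batch size and point count below are arbitrary.

import torch
from dgcnn.model_class import EdgeConvNewAtten, SelfAttention  # assumes the patched module is importable

# (batch_size, in_channels, num_points); in_channels must match the 64 wired into DgcnnClass.edge_conv2
x = torch.randn(2, 64, 1024)

edge_conv = EdgeConvNewAtten(64, 128)
out = edge_conv(x, k=20)
print(out.shape)  # expected (2, 128, 1024, 20); DgcnnClass would max-pool over the last (k) dimension

# The attention block on its own expects embed_dim = 2 * in_channels * k = 2560 for the wiring above,
# and 2560 is divisible by num_heads=8 as MultiheadAttention requires
attn = SelfAttention(in_channels=2560, num_heads=8, dropout=0.1)
tokens = torch.randn(2, 1024, 2560)  # (batch, num_points, embed_dim)
print(attn(tokens).shape)            # expected (2, 1024, 2560)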