Commit bfd9359b authored by Maciej Wielgosz

self attention added

parent 6c1e540d
@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import MultiheadAttention
# TODO: update with https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
# TODO: there are more conv layers there
@@ -48,6 +49,77 @@ class EdgeConvNew(nn.Module):
        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
        return idx

# implement self attention
class SelfAttention(nn.Module):
    def __init__(self, in_channels, num_heads, dropout):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.dropout = dropout
        self.self_attention = MultiheadAttention(in_channels, num_heads=num_heads, dropout=dropout)

    def forward(self, x):
        batch_size = x.size(0)
        num_points = x.size(2)
        x = x.view(batch_size, -1, num_points)
        # MultiheadAttention without batch_first expects (seq_len, batch, embed_dim)
        x = x.permute(1, 0, 2)
        out, attn = self.self_attention(x, x, x)
        out = out.permute(1, 0, 2)
        out = out.view(batch_size, -1, num_points)
        return out
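
# Shape sketch for SelfAttention (illustrative, with assumed sizes): a (2, 128, 2560) input,
# i.e. (batch_size, num_points, 2*in_channels*k) as produced by EdgeConvNewAtten below, is
# permuted to (128, 2, 2560) -- the (seq_len, batch, embed_dim) layout nn.MultiheadAttention
# expects without batch_first -- and permuted back to (2, 128, 2560) after attention.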

class EdgeConvNewAtten(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EdgeConvNewAtten, self).__init__()
        self.in_channels = in_channels
        self.conv = nn.Sequential(
            nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )
        # embed_dim of the attention is the per-point feature vector flattened over the k=20 neighbours
        self.self_attention = SelfAttention(2*in_channels*20, num_heads=8, dropout=0.1)

    def forward(self, x, k=20):
        batch_size = x.size(0)
        num_points = x.size(2)
        x = x.view(batch_size, -1, num_points)

        idx = self.knn(x, k=k)  # (batch_size, num_points, k)
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)

        _, num_dims, _ = x.size()

        # gather the k nearest neighbours of every point and build edge features (x_j - x_i, x_i)
        x = x.transpose(2, 1).contiguous()
        feature = x.view(batch_size*num_points, -1)[idx, :]
        feature = feature.view(batch_size, num_points, k, num_dims)
        x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
        feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()

        feature = self.conv(feature)  # (batch_size, out_channels, num_points, k)
        # print("feature", feature.shape)
        # flatten each point's (out_channels, k) block so it can be fed to the attention
        feature = feature.permute(0, 2, 1, 3).contiguous()
        feature = feature.view(batch_size, num_points, -1)
        # print("feature", feature.shape)
        feature = self.self_attention(feature)  # (batch_size, num_points, out_channels*k)
        feature = feature.reshape(batch_size, -1, num_points, k).contiguous()
        # print("feature", feature.shape)
        return feature

    def knn(self, x, k):
        x = x.transpose(2, 1)
        pairwise_distance = torch.cdist(x, x, p=2)
        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)  # (batch_size, num_points, k)
        return idx

class EdgeConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__()
@@ -133,7 +205,7 @@ class DgcnnClass(nn.Module):
        super(DgcnnClass, self).__init__()
        self.transform_net = Transform_Net()
        self.edge_conv1 = EdgeConvNew(3, 64)
-       self.edge_conv2 = EdgeConvNew(64, 128)
+       self.edge_conv2 = EdgeConvNewAtten(64, 128)
        self.bn5 = nn.BatchNorm1d(256)
        self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
                                   self.bn5,
@@ -160,7 +232,6 @@ class DgcnnClass(nn.Module):
        dim = x.size(2)
        x = x.view(batch_size, dim, num_points)
        x1 = self.edge_conv1(x)
        x1 = x1.max(dim=-1, keepdim=False)[0]
......
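
A quick way to sanity-check the new attention block is a short smoke test like the sketch below; it assumes EdgeConvNewAtten is in scope (imported from this model file), and the batch and point counts are made up. The expected output shape follows from the reshape at the end of EdgeConvNewAtten.forward.

import torch

# minimal smoke test with assumed sizes: 2 clouds, 64 input channels, 128 points
block = EdgeConvNewAtten(in_channels=64, out_channels=128)
pts = torch.randn(2, 64, 128)   # (batch_size, in_channels, num_points)
out = block(pts, k=20)          # k must stay 20, matching SelfAttention's hard-coded embed_dim
print(out.shape)                # torch.Size([2, 128, 128, 20]) -> (batch, out_channels, num_points, k)

Note that this only works because out_channels == 2*in_channels here, so the flattened per-point feature (out_channels*k) matches the embed_dim (2*in_channels*20) the attention module was built with.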