diff --git a/README.md b/README.md
index e39eadc67890857f31b24547a9d23bbcd669cc2f..69af2963ad4866d999e9a95330413a81616d8169 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Pytorch Implementation of Various Point Transformers
 
-Recently, various methods applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a pytorch implementation for these methods and aims to compare them under a fair setting. Currently, Point Transformer (Nico Engel et al.) is implemented.
+Recently, various methods have applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a PyTorch implementation of these methods and aims to compare them under a fair setting. Currently, Point Transformer (Nico Engel et al.) and Point Transformer (Hengshuang Zhao et al.) are implemented.
 
 ## Classification
 
@@ -8,6 +8,7 @@ Recently, various methods applied transformers to point clouds
 Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `modelnet40_normal_resampled`.
 
 ### Run
+Select the model to use in `config/config.yaml` and run
 ```
 python train.py
 ```
@@ -17,5 +18,4 @@ TBA
 ### Miscellaneous
 Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
 ## TODOs
-- [ ] implement Point Transformer (Hengshuang Zhao et al.)
 - [ ] implement PCT: Point Cloud Transformer (Meng-Hao Guo et al.)
\ No newline at end of file
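With the `defaults` list added to `config/config.yaml` in the next file, Hydra treats `model` as a config group, so besides editing the yaml the model can be switched per run from the command line, a minimal sketch assuming Hydra's standard group-override syntax:

```
python train.py model=Hengshuang
```

The updated `hydra.run.dir: outputs/${model}` setting then keeps each model's logs and checkpoints in its own output directory.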
diff --git a/config.yaml b/config/config.yaml
similarity index 59%
rename from config.yaml
rename to config/config.yaml
index f6e59a8aa7832d6fc9726e1e46701241473701c4..2c933de487aa4d434add44c92dce10e4c2eafe62 100644
--- a/config.yaml
+++ b/config/config.yaml
@@ -1,4 +1,4 @@
-batch_size: 24
+batch_size: 16
 epoch: 200
 learning_rate: 1e-3
 gpu: 1
@@ -7,6 +7,9 @@ optimizer: Adam
 weight_decay: 1e-4
 normal: True
 
+defaults:
+  - model: Nico
+
 hydra:
   run:
-    dir: outputs
\ No newline at end of file
+    dir: outputs/${model}
\ No newline at end of file
diff --git a/config/model/Hengshuang.yaml b/config/model/Hengshuang.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6a9e188d1a970176b35ba7088565fc164f01908
--- /dev/null
+++ b/config/model/Hengshuang.yaml
@@ -0,0 +1,5 @@
+# @package _group_
+nneighbor: 16
+nblocks: 4
+transformer_dim: 512
+name: Hengshuang
\ No newline at end of file
diff --git a/config/model/Nico.yaml b/config/model/Nico.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaed132363eea5516edda5063df6eb30b107f5ea
--- /dev/null
+++ b/config/model/Nico.yaml
@@ -0,0 +1,9 @@
+# @package _group_
+n_head: 8
+m: 4
+k: 64
+global_k: 128
+global_dim: 512
+local_dim: 256
+reduce_dim: 64
+name: Nico
\ No newline at end of file
diff --git a/models/Hengshuang/model.py b/models/Hengshuang/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b97e807dac412916f2959a9aabece44b2f9ce7d
--- /dev/null
+++ b/models/Hengshuang/model.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+from pointnet_util import PointNetSetAbstraction
+from .transformer import TransformerBlock
+
+
+class TransitionDown(nn.Module):
+    def __init__(self, k, nneighbor, channels) -> None:
+        super().__init__()
+        self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True)
+
+    def forward(self, xyz, points):
+        return self.sa(xyz, points)
+
+
+class PointTransformer(nn.Module):
+    def __init__(self, cfg) -> None:
+        super().__init__()
+        npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim
+        self.fc1 = nn.Sequential(
+            nn.Linear(d_points, 32),
+            nn.ReLU(),
+            nn.Linear(32, 32)
+        )
+        self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor)
+        self.transition_downs = nn.ModuleList()
+        self.transformers = nn.ModuleList()
+        for i in range(nblocks):
+            channel = 32 * 2 ** (i + 1)
+            self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel]))
+            self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor))
+
+        self.fc2 = nn.Sequential(
+            nn.Linear(32 * 2 ** nblocks, 256),
+            nn.ReLU(),
+            nn.Linear(256, 64),
+            nn.ReLU(),
+            nn.Linear(64, n_c)
+        )
+        self.nblocks = nblocks
+
+    def forward(self, x):
+        xyz = x[..., :3]
+        points = self.transformer1(xyz, self.fc1(x))[0]
+        for i in range(self.nblocks):
+            xyz, points = self.transition_downs[i](xyz, points)
+            points = self.transformers[i](xyz, points)[0]
+        res = self.fc2(points.mean(1))
+        return res
\ No newline at end of file
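To make the shapes in `PointTransformer.__init__` concrete: each `TransitionDown` keeps a quarter of the points while the feature width doubles. A minimal sketch of the schedule, assuming `num_point: 1024` in `config/config.yaml` (the value itself is not shown in this patch) and `nblocks: 4` from `config/model/Hengshuang.yaml`:

```python
# Resolution/width schedule of the Hengshuang backbone (assumes npoints=1024).
npoints, nblocks = 1024, 4
for i in range(nblocks):
    channel = 32 * 2 ** (i + 1)          # feature width after block i
    print(f"block {i}: {npoints // 4 ** (i + 1)} points, {channel} dims")
# block 0: 256 points, 64 dims
# block 1: 64 points, 128 dims
# block 2: 16 points, 256 dims
# block 3: 4 points, 512 dims
```

The final `points.mean(1)` therefore pools 512-dim features, matching `fc2`'s input width of `32 * 2 ** nblocks`.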
diff --git a/models/Hengshuang/transformer.py b/models/Hengshuang/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b688c5f9af67e65c77b17cb3572284d64da901b
--- /dev/null
+++ b/models/Hengshuang/transformer.py
@@ -0,0 +1,45 @@
+from pointnet_util import index_points, square_distance
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class TransformerBlock(nn.Module):
+    def __init__(self, d_points, d_model, k) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(d_points, d_model)
+        self.fc2 = nn.Linear(d_model, d_points)
+        self.fc_delta = nn.Sequential(
+            nn.Linear(3, d_model),
+            nn.ReLU(),
+            nn.Linear(d_model, d_model)
+        )
+        self.fc_gamma = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.ReLU(),
+            nn.Linear(d_model, d_model)
+        )
+        self.w_qs = nn.Linear(d_model, d_model, bias=False)
+        self.w_ks = nn.Linear(d_model, d_model, bias=False)
+        self.w_vs = nn.Linear(d_model, d_model, bias=False)
+        self.k = k
+
+    # xyz: b x n x 3, features: b x n x f
+    def forward(self, xyz, features):
+        dists = square_distance(xyz, xyz)
+        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
+        knn_xyz = index_points(xyz, knn_idx)
+
+        pre = features
+        x = self.fc1(features)
+        q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx)
+
+        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x f
+
+        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
+        attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2)  # b x n x k x f
+
+        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
+        res = self.fc2(res) + pre
+        return res, attn
+    
\ No newline at end of file
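For reference, this `TransformerBlock` implements the vector self-attention of Zhao et al., with `w_qs`/`w_ks`/`w_vs` as the embeddings $\varphi$, $\psi$, $\alpha$, `fc_delta` as the positional encoding $\theta$, `fc_gamma` as the attention mapping $\gamma$, and the softmax $\rho$ taken over each point's $k$ nearest neighbours:

$$
y_i = \sum_{x_j \in \mathcal{X}(i)} \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\big) \odot \big(\alpha(x_j) + \delta\big),
\qquad \delta = \theta(p_i - p_j)
$$

The `fc1`/`fc2` projections and the `+ pre` residual wrap this attention; the extra `1 / np.sqrt(k.size(-1))` scaling mirrors scaled dot-product attention and is this implementation's choice rather than part of the paper's formulation.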
diff --git a/model.py b/models/Nico/model.py
similarity index 60%
rename from model.py
rename to models/Nico/model.py
index 5ec388f5f50169ab5375accddd39541144828834..568fa6ecefc780b9074ec8f8a8c105c332a05e8e 100644
--- a/model.py
+++ b/models/Nico/model.py
@@ -1,22 +1,34 @@
 import torch
 import torch.nn as nn
 from pointnet_util import PointNetSetAbstractionMsg
-from transformer import MultiHeadAttention
+from .transformer import MultiHeadAttention
 
 
 class SortNet(nn.Module):
     def __init__(self, d_model, d_points=6, k=64):
         super().__init__()
-        self.fc = nn.Linear(d_model, 1)
+        self.fc = nn.Sequential(
+            nn.Linear(d_model, 256),
+            nn.ReLU(),
+            nn.Linear(256, 64),
+            nn.ReLU(),
+            nn.Linear(64, 1)
+        )
         self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
-        self.fc_agg = nn.Linear(64 + 128 + 128, d_model - 1 - d_points)
+        self.fc_agg = nn.Sequential(
+            nn.Linear(64 + 128 + 128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model - 1 - d_points),
+        )
         self.k = k
         self.d_points = d_points
 
     def forward(self, points, features):
         score = self.fc(features)
         topk_idx = torch.topk(score[..., 0], self.k, 1)[1]
-        features_abs = self.sa(points[..., :3].transpose(1, 2), features.transpose(1, 2), topk_idx)[1].transpose(1, 2)
+        features_abs = self.sa(points[..., :3], features, topk_idx)[1]
         res = torch.cat((self.fc_agg(features_abs),
                          torch.gather(score, 1, topk_idx[..., None].expand(-1, -1, score.size(-1))),
                          torch.gather(points, 1, topk_idx[..., None].expand(-1, -1, points.size(-1)))), -1)
@@ -26,7 +38,13 @@ class SortNet(nn.Module):
 class LocalFeatureGeneration(nn.Module):
     def __init__(self, d_model, m, k, d_points=6, n_head=4):
         super().__init__()
-        self.fc = nn.Linear(d_points, d_model)
+        self.fc = nn.Sequential(
+            nn.Linear(d_points, 64),
+            nn.ReLU(),
+            nn.Linear(64, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model)
+        )
         self.sortnets = nn.ModuleList([SortNet(d_model, k=k)] * m)
         self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head)
 
@@ -40,25 +58,45 @@ class LocalFeatureGeneration(nn.Module):
 class GlobalFeatureGeneration(nn.Module):
     def __init__(self, d_model, k, d_points=6, n_head=4):
         super().__init__()
-        self.fc = nn.Linear(d_points, d_model)
+        self.fc = nn.Sequential(
+            nn.Linear(d_points, 64),
+            nn.ReLU(),
+            nn.Linear(64, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model)
+        )
         self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
         self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head)
-        self.fc_agg = nn.Linear(64 + 128 + 128, d_model)
+        self.fc_agg = nn.Sequential(
+            nn.Linear(64 + 128 + 128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model),
+        )
 
     def forward(self, points):
         x = self.fc(points)
         x, _ = self.att(x, x, x)
-        out = self.fc_agg(self.sa(points[..., :3].transpose(1, 2), x.transpose(1, 2))[1].transpose(1, 2))
+        out = self.fc_agg(self.sa(points[..., :3], x)[1])
         return out, x
 
 
 class PointTransformer(nn.Module):
-    def __init__(self, d_model_l, d_model_g, d_reduce, m, k, n_c, d_points=6, n_head=4):
+    def __init__(self, cfg):
         super().__init__()
+        d_model_l, d_model_g, d_reduce, m, k, n_c, d_points, n_head \
+            = cfg.model.global_dim, cfg.model.local_dim, cfg.model.reduce_dim, cfg.model.m, cfg.model.k, cfg.num_class, cfg.input_dim, cfg.model.n_head
         self.lfg = LocalFeatureGeneration(d_model=d_model_l, m=m, k=k, d_points=d_points)
-        self.gfg = GlobalFeatureGeneration(d_model=d_model_g, k=128, d_points=d_points)
+        self.gfg = GlobalFeatureGeneration(d_model=d_model_g, k=cfg.model.global_k, d_points=d_points)
         self.lg_att = MultiHeadAttention(n_head, d_model_l, d_model_g, d_model_l // n_head, d_model_l // n_head)
-        self.fc = nn.Linear(d_model_l, d_reduce)
+        self.fc = nn.Sequential(
+            nn.Linear(d_model_l, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_reduce),
+        )
         self.fc_cls = nn.Sequential(
             nn.Linear(k * m * d_reduce, 1024),
             nn.ReLU(),
diff --git a/transformer.py b/models/Nico/transformer.py
similarity index 100%
rename from transformer.py
rename to models/Nico/transformer.py
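The reworked `SortNet.forward` keeps its score-then-gather pattern: score every point, keep the top-k, and gather the corresponding scores and coordinates. A self-contained sketch of just that pattern, using toy sizes (the real ones come from the config):

```python
import torch

B, N, K, D = 2, 8, 4, 6                  # toy batch, points, top-k, point dims
score = torch.rand(B, N, 1)              # per-point score (SortNet's self.fc)
points = torch.rand(B, N, D)

topk_idx = torch.topk(score[..., 0], K, 1)[1]                            # B x K
picked = torch.gather(points, 1, topk_idx[..., None].expand(-1, -1, D))  # B x K x D
print(picked.shape)                      # torch.Size([2, 4, 6])
```

Note also that `cfg.model.global_dim` feeds `d_model_l` and `cfg.model.local_dim` feeds `d_model_g`; the names read crossed, but this preserves the previous positional values `(512, 256)` from the old constructor call.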
diff --git a/pointnet_util.py b/pointnet_util.py
index 58aef06fa260c54ad0f0b8dc7ebcb8fced964e87..8e042a973cf968540bec6f73fed09503617db819 100644
--- a/pointnet_util.py
+++ b/pointnet_util.py
@@ -96,7 +96,7 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
     return group_idx
 
 
-def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
+def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
     """
     Input:
         npoint:
@@ -110,11 +110,15 @@
     """
     B, N, C = xyz.shape
     S = npoint
-    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint, C]
+    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    torch.cuda.empty_cache()
     new_xyz = index_points(xyz, fps_idx)
     torch.cuda.empty_cache()
-    idx = query_ball_point(radius, nsample, xyz, new_xyz)
+    if knn:
+        dists = square_distance(new_xyz, xyz)  # B x npoint x N
+        idx = dists.argsort()[:, :, :nsample]  # B x npoint x K
+    else:
+        idx = query_ball_point(radius, nsample, xyz, new_xyz)
     torch.cuda.empty_cache()
     grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
     torch.cuda.empty_cache()
@@ -153,11 +157,12 @@
 
 
 class PointNetSetAbstraction(nn.Module):
-    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
+    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
         super(PointNetSetAbstraction, self).__init__()
         self.npoint = npoint
         self.radius = radius
         self.nsample = nsample
+        self.knn = knn
         self.mlp_convs = nn.ModuleList()
         self.mlp_bns = nn.ModuleList()
         last_channel = in_channel
@@ -170,20 +175,16 @@ class PointNetSetAbstraction(nn.Module):
     def forward(self, xyz, points):
         """
         Input:
-            xyz: input points position data, [B, C, N]
-            points: input points data, [B, D, N]
+            xyz: input points position data, [B, N, C]
+            points: input points data, [B, N, C]
         Return:
-            new_xyz: sampled points position data, [B, C, S]
-            new_points_concat: sample points feature data, [B, D', S]
+            new_xyz: sampled points position data, [B, S, C]
+            new_points_concat: sample points feature data, [B, S, D']
         """
-        xyz = xyz.permute(0, 2, 1)
-        if points is not None:
-            points = points.permute(0, 2, 1)
-
         if self.group_all:
             new_xyz, new_points = sample_and_group_all(xyz, points)
         else:
-            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
+            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
         # new_xyz: sampled points position data, [B, npoint, C]
         # new_points: sampled points data, [B, npoint, nsample, C+D]
         new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
@@ -191,17 +192,17 @@
             bn = self.mlp_bns[i]
             new_points = F.relu(bn(conv(new_points)))
 
-        new_points = torch.max(new_points, 2)[0]
-        new_xyz = new_xyz.permute(0, 2, 1)
+        new_points = torch.max(new_points, 2)[0].transpose(1, 2)
         return new_xyz, new_points
 
 
 class PointNetSetAbstractionMsg(nn.Module):
-    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
+    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
         super(PointNetSetAbstractionMsg, self).__init__()
         self.npoint = npoint
         self.radius_list = radius_list
         self.nsample_list = nsample_list
+        self.knn = knn
         self.conv_blocks = nn.ModuleList()
         self.bn_blocks = nn.ModuleList()
         for i in range(len(mlp_list)):
@@ -224,9 +225,6 @@
             new_xyz: sampled points position data, [B, C, S]
             new_points_concat: sample points feature data, [B, D', S]
         """
-        xyz = xyz.permute(0, 2, 1)
-        if points is not None:
-            points = points.permute(0, 2, 1)
 
         B, N, C = xyz.shape
         S = self.npoint
@@ -234,7 +232,11 @@
         new_points_list = []
         for i, radius in enumerate(self.radius_list):
             K = self.nsample_list[i]
-            group_idx = query_ball_point(radius, K, xyz, new_xyz)
+            if self.knn:
+                dists = square_distance(new_xyz, xyz)  # B x npoint x N
+                group_idx = dists.argsort()[:, :, :K]  # B x npoint x K
+            else:
+                group_idx = query_ball_point(radius, K, xyz, new_xyz)
             grouped_xyz = index_points(xyz, group_idx)
             grouped_xyz -= new_xyz.view(B, S, 1, C)
             if points is not None:
@@ -251,11 +253,11 @@
             new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
             new_points_list.append(new_points)
 
-        new_xyz = new_xyz.permute(0, 2, 1)
-        new_points_concat = torch.cat(new_points_list, dim=1)
+        new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
         return new_xyz, new_points_concat
 
 
+# Note: this function swaps N and C
 class PointNetFeaturePropagation(nn.Module):
     def __init__(self, in_channel, mlp):
         super(PointNetFeaturePropagation, self).__init__()
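The new `knn=True` path swaps the radius-based ball query for k-nearest-neighbour grouping, which always returns exactly `nsample` neighbours and needs no radius (hence `TransitionDown` passing `0` for it). A standalone sketch of that path, with `torch.cdist` standing in for the repo's `square_distance`:

```python
import torch

B, N, S, K = 2, 1024, 256, 16
xyz = torch.rand(B, N, 3)                # all input points
new_xyz = torch.rand(B, S, 3)            # FPS-sampled centroids

dists = torch.cdist(new_xyz, xyz) ** 2   # B x S x N squared distances
idx = dists.argsort()[:, :, :K]          # B x S x K neighbour indices
print(idx.shape)                         # torch.Size([2, 256, 16])
```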
diff --git a/train.py b/train.py
index ea8629208786c44fa8169fb9c05438ce8ded1414..2b6c1cb702fa13402465740c589872cd62d7bf36 100644
--- a/train.py
+++ b/train.py
@@ -16,8 +16,7 @@
 import provider
 import importlib
 import shutil
 import hydra
-from model import PointTransformer
-
+import omegaconf
 
 def test(model, loader, num_class=40):
@@ -41,8 +40,10 @@
     instance_acc = np.mean(mean_correct)
     return instance_acc, class_acc
 
-@hydra.main(config_name='config')
+
+@hydra.main(config_path='config', config_name='config')
 def main(args):
+    omegaconf.OmegaConf.set_struct(args, False)
     '''HYPER PARAMETER'''
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
 
@@ -60,10 +61,11 @@
     testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
 
     '''MODEL LOADING'''
-    num_class = 40
-    shutil.copy(hydra.utils.to_absolute_path('model.py'), '.')
+    args.num_class = 40
+    args.input_dim = 6 if args.normal else 3
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
 
-    classifier = PointTransformer(512, 256, 64, m=4, k=64, n_c=40, d_points=6).cuda()
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformer')(args).cuda()
     criterion = torch.nn.CrossEntropyLoss()
 
     try:
@@ -87,7 +89,7 @@
     else:
         optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
 
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
     global_epoch = 0
     global_step = 0
     best_instance_acc = 0.0
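The scheduler change is easy to sanity-check: with the configured `learning_rate: 1e-3` and 200 epochs, `StepLR(step_size=50, gamma=0.3)` yields four long plateaus instead of ten short ones. A minimal sketch, with SGD on a dummy parameter standing in for the real Adam optimizer:

```python
import torch

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.3)
for epoch in range(200):
    if epoch % 50 == 0:
        # prints roughly: 0 1e-3, 50 3e-4, 100 9e-5, 150 2.7e-5
        print(epoch, sched.get_last_lr()[0])
    opt.step()
    sched.step()
```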