diff --git a/.gitignore b/.gitignore
index 1c2f8207220eba333542658c0a023c2e466185c6..6c1eb2247e93d0fd6f61da12c4c9327db9c93222 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
-.vscode
-__pycache__/
-modelnet40_normal_resampled/
-outputs/
+.vscode
+__pycache__/
+modelnet40_normal_resampled/
+outputs/
 log/
\ No newline at end of file
diff --git a/README.md b/README.md
index 9bac2b6c94d887156cffb33436ff737822d6dd07..d691a91d08baa8ceea55b7bc1b427e15dca5b86b 100644
--- a/README.md
+++ b/README.md
@@ -1,26 +1,39 @@
-# Pytorch Implementation of Various Point Transformers
-
-Recently, various methods applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a pytorch implementation for these methods and aims to compare them under a fair setting. Currently, all three methods are implemented, while tuning their hyperparameters.
-
-
-## Classification
-### Data Preparation
-Download alignment **ModelNet** [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save in `modelnet40_normal_resampled`.
-
-### Run
-Change which method to use in `config/config.yaml` and run
-```
-python train.py
-```
-### Results
-Using Adam with learning rate decay 0.3 for every 50 epochs, train for 200 epochs; data augmentation follows [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). For Hengshuang and Nico, initial LR is 1e-3 (I would appreciate if someone could fine-tune these hyper-paramters); for Menghao, initial LR is 1e-4, as suggested by the [author](https://github.com/MenghaoGuo). ModelNet40 classification results (instance average) are listed below:
-| Model | Accuracy |
-|--|--|
-| Hengshuang | 89.6|
-| Menghao | 92.6 |
-| Nico | 85.5 |
-
-### Miscellaneous
-Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
-Code for [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688) is adapted from the author's Jittor implementation https://github.com/MenghaoGuo/PCT.
-
+# PyTorch Implementation of Various Point Transformers
+
+Recently, various methods have applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a PyTorch implementation of these methods and aims to compare them under a fair setting. Currently, all three methods are implemented; their hyperparameters are still being tuned.
+
+
+## Classification
+### Data Preparation
+Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `modelnet40_normal_resampled`.
+
+### Run
+Change which method to use in `config/cls.yaml` and run
+```
+python train_cls.py
+```
+### Results
+Training uses Adam for 200 epochs, with the learning rate decayed by 0.3 every 50 epochs; data augmentation follows [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). For Hengshuang and Nico, the initial LR is 1e-3 (I would appreciate it if someone could fine-tune these hyper-parameters); for Menghao, the initial LR is 1e-4, as suggested by the [author](https://github.com/MenghaoGuo).
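+A minimal sketch of this schedule with standard PyTorch pieces (placeholder `model` and hypothetical `train_one_epoch`; `train_cls.py` may implement the decay differently):
+```python
+import torch
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-4 for Menghao
+# decay the learning rate by 0.3 every 50 epochs, for 200 epochs total
+scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
+for epoch in range(200):
+    train_one_epoch(model, optimizer)  # hypothetical single-epoch training step
+    scheduler.step()
+```
+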
+ModelNet40 classification results (instance average) are listed below:
+| Model | Accuracy |
+|--|--|
+| Hengshuang | 89.6|
+| Menghao | 92.6 |
+| Nico | 85.5 |
+
+
+## Part Segmentation
+### Data Preparation
+Download the aligned **ShapeNet** dataset [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and save it in `data/shapenetcore_partanno_segmentation_benchmark_v0_normal`.
+
+### Run
+Change which method to use in `config/partseg.yaml` and run
+```
+python train_partseg.py
+```
+### Results
+Currently, only Hengshuang's method is implemented.
+
+### Miscellaneous
+Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
+Code for [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688) is adapted from the author's Jittor implementation https://github.com/MenghaoGuo/PCT.
+
diff --git a/config/config.yaml b/config/cls.yaml
similarity index 74%
rename from config/config.yaml
rename to config/cls.yaml
index a179f365d498171a3ef7520956b58008c57a8fad..1d7783d0063871b78c3450073297cf58732286ec 100644
--- a/config/config.yaml
+++ b/config/cls.yaml
@@ -12,8 +12,8 @@ defaults:
 
 hydra:
   run:
-    dir: log/${model.name}
+    dir: log/cls/${model.name}
 
   sweep:
-    dir: log
+    dir: log/cls
     subdir: ${model.name}
\ No newline at end of file
diff --git a/config/model/Hengshuang.yaml b/config/model/Hengshuang.yaml
index b6a9e188d1a970176b35ba7088565fc164f01908..f3cd12e5496c8499abddf41589c501b1b46e6f18 100644
--- a/config/model/Hengshuang.yaml
+++ b/config/model/Hengshuang.yaml
@@ -1,5 +1,5 @@
-# @package _group_
-nneighbor: 16
-nblocks: 4
-transformer_dim: 512
+# @package _group_
+nneighbor: 16
+nblocks: 4
+transformer_dim: 512
 name: Hengshuang
\ No newline at end of file
diff --git a/config/model/Menghao.yaml b/config/model/Menghao.yaml
index e23bed464de68138a5768abe7e1a0ffac2533708..fea1a8083f378f5d93409de6468f760c7b7c05ec 100644
--- a/config/model/Menghao.yaml
+++ b/config/model/Menghao.yaml
@@ -1,2 +1,2 @@
-# @package _group_
+# @package _group_
 name: Menghao
\ No newline at end of file
diff --git a/config/model/Nico.yaml b/config/model/Nico.yaml
index eaed132363eea5516edda5063df6eb30b107f5ea..9fd6b2e50f84eddcc244a351533f4fab8837e8b5 100644
--- a/config/model/Nico.yaml
+++ b/config/model/Nico.yaml
@@ -1,9 +1,9 @@
-# @package _group_
-n_head: 8
-m: 4
-k: 64
-global_k: 128
-global_dim: 512
-local_dim: 256
-reduce_dim: 64
+# @package _group_
+n_head: 8
+m: 4
+k: 64
+global_k: 128
+global_dim: 512
+local_dim: 256
+reduce_dim: 64
 name: Nico
\ No newline at end of file
diff --git a/config/partseg.yaml b/config/partseg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e3dd3e59155b058696b7682256a3d92b446da64
--- /dev/null
+++ b/config/partseg.yaml
@@ -0,0 +1,21 @@
+batch_size: 16
+epoch: 200
+learning_rate: 1e-3
+gpu: 1
+num_point: 1024
+optimizer: Adam
+weight_decay: 1e-4
+normal: True
+lr_decay: 0.5
+step_size: 20
+
+defaults:
+  - model: Hengshuang
+
+hydra:
+  run:
+    dir: log/partseg/${model.name}
+
+  sweep:
+    dir: log/partseg
+    subdir: ${model.name}
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
index dc56eea011c0d01782c0b1f66310aecacd91aaa7..62cff7d5dd36a502d83027a04631281ae27d63a2 100644
--- a/dataset.py
+++ b/dataset.py
@@ -3,6 +3,7 @@ import os
 from torch.utils.data import Dataset
 import torch
 from pointnet_util import farthest_point_sample, pc_normalize
+import json
 
 
 class ModelNetDataLoader(Dataset):
@@ -60,6 +61,108 @@ class ModelNetDataLoader(Dataset):
         return self._get_item(index)
 
 
+class PartNormalDataset(Dataset):
+    def __init__(self, root='./data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):
+        self.npoints = npoints
+        self.root = root
+        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
+        self.cat = {}
+        self.normal_channel = normal_channel
+
+
+        with open(self.catfile, 'r') as f:
+            for line in f:
+                ls = line.strip().split()
+                self.cat[ls[0]] = ls[1]
+        self.cat = {k: v for k, v in self.cat.items()}
+        self.classes_original = dict(zip(self.cat, range(len(self.cat))))
+
+        if class_choice is not None:
+            self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
+        # print(self.cat)
+
+        self.meta = {}
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
+            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
+            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
+            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        for item in self.cat:
+            # print('category', item)
+            self.meta[item] = []
+            dir_point = os.path.join(self.root, self.cat[item])
+            fns = sorted(os.listdir(dir_point))
+            # print(fns[0][0:-4])
+            if split == 'trainval':
+                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
+            elif split == 'train':
+                fns = [fn for fn in fns if fn[0:-4] in train_ids]
+            elif split == 'val':
+                fns = [fn for fn in fns if fn[0:-4] in val_ids]
+            elif split == 'test':
+                fns = [fn for fn in fns if fn[0:-4] in test_ids]
+            else:
+                print('Unknown split: %s. Exiting..'
% (split)) + exit(-1) + + # print(os.path.basename(fns)) + for fn in fns: + token = (os.path.splitext(os.path.basename(fn))[0]) + self.meta[item].append(os.path.join(dir_point, token + '.txt')) + + self.datapath = [] + for item in self.cat: + for fn in self.meta[item]: + self.datapath.append((item, fn)) + + self.classes = {} + for i in self.cat.keys(): + self.classes[i] = self.classes_original[i] + + # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels + self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], + 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], + 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], + 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], + 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} + + # for cat in sorted(self.seg_classes.keys()): + # print(cat, self.seg_classes[cat]) + + self.cache = {} # from index to (point_set, cls, seg) tuple + self.cache_size = 20000 + + + def __getitem__(self, index): + if index in self.cache: + point_set, cls, seg = self.cache[index] + else: + fn = self.datapath[index] + cat = self.datapath[index][0] + cls = self.classes[cat] + cls = np.array([cls]).astype(np.int32) + data = np.loadtxt(fn[1]).astype(np.float32) + if not self.normal_channel: + point_set = data[:, 0:3] + else: + point_set = data[:, 0:6] + seg = data[:, -1].astype(np.int32) + if len(self.cache) < self.cache_size: + self.cache[index] = (point_set, cls, seg) + point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) + + choice = np.random.choice(len(seg), self.npoints, replace=True) + # resample + point_set = point_set[choice, :] + seg = seg[choice] + + return point_set, cls, seg + + def __len__(self): + return len(self.datapath) + + if __name__ == '__main__': data = ModelNetDataLoader('modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True) DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) diff --git a/models/Hengshuang/model.py b/models/Hengshuang/model.py index 5b97e807dac412916f2959a9aabece44b2f9ce7d..161d7bc5eaf4b34fe076a06f0ee902b2dcc00487 100644 --- a/models/Hengshuang/model.py +++ b/models/Hengshuang/model.py @@ -1,49 +1,141 @@ -import torch -import torch.nn as nn -from pointnet_util import PointNetSetAbstraction -from .transformer import TransformerBlock - - -class TransitionDown(nn.Module): - def __init__(self, k, nneighbor, channels) -> None: - super().__init__() - self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) - - def forward(self, xyz, points): - return self.sa(xyz, points) - - -class PointTransformer(nn.Module): - def __init__(self, cfg) -> None: - super().__init__() - npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim - self.fc1 = nn.Sequential( - nn.Linear(d_points, 32), - nn.ReLU(), - nn.Linear(32, 32) - ) - self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) - self.transition_downs = nn.ModuleList() - self.transformers = nn.ModuleList() - for i in range(nblocks): - channel = 32 * 2 ** (i + 1) - self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) - self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) - - self.fc2 = nn.Sequential( - nn.Linear(32 * 2 ** nblocks, 256), - 
nn.ReLU(), - nn.Linear(256, 64), - nn.ReLU(), - nn.Linear(64, n_c) - ) - self.nblocks = nblocks - - def forward(self, x): - xyz = x[..., :3] - points = self.transformer1(xyz, self.fc1(x))[0] - for i in range(self.nblocks): - xyz, points = self.transition_downs[i](xyz, points) - points = self.transformers[i](xyz, points)[0] - res = self.fc2(points.mean(1)) - return res \ No newline at end of file +import torch +import torch.nn as nn +from pointnet_util import PointNetFeaturePropagation, PointNetSetAbstraction +from .transformer import TransformerBlock + + +class TransitionDown(nn.Module): + def __init__(self, k, nneighbor, channels): + super().__init__() + self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) + + def forward(self, xyz, points): + return self.sa(xyz, points) + + +class TransitionUp(nn.Module): + def __init__(self, dim1, dim2, dim_out): + class SwapAxes(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.transpose(1, 2) + + super().__init__() + self.fc1 = nn.Sequential( + nn.Linear(dim1, dim_out), + SwapAxes(), + nn.BatchNorm1d(dim_out), # TODO + SwapAxes(), + nn.ReLU(), + ) + self.fc2 = nn.Sequential( + nn.Linear(dim2, dim_out), + SwapAxes(), + nn.BatchNorm1d(dim_out), # TODO + SwapAxes(), + nn.ReLU(), + ) + self.fp = PointNetFeaturePropagation(-1, []) + + def forward(self, xyz1, points1, xyz2, points2): + feats1 = self.fc1(points1) + feats2 = self.fc2(points2) + feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2) + return feats1 + feats2 + + +class Backbone(nn.Module): + def __init__(self, cfg): + super().__init__() + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc1 = nn.Sequential( + nn.Linear(d_points, 32), + nn.ReLU(), + nn.Linear(32, 32) + ) + self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) + self.transition_downs = nn.ModuleList() + self.transformers = nn.ModuleList() + for i in range(nblocks): + channel = 32 * 2 ** (i + 1) + self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) + self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) + self.nblocks = nblocks + + def forward(self, x): + xyz = x[..., :3] + points = self.transformer1(xyz, self.fc1(x))[0] + + xyz_and_feats = [(xyz, points)] + for i in range(self.nblocks): + xyz, points = self.transition_downs[i](xyz, points) + points = self.transformers[i](xyz, points)[0] + xyz_and_feats.append((xyz, points)) + return points, xyz_and_feats + + +class PointTransformerCls(nn.Module): + def __init__(self, cfg): + super().__init__() + self.backbone = Backbone(cfg) + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc2 = nn.Sequential( + nn.Linear(32 * 2 ** nblocks, 256), + nn.ReLU(), + nn.Linear(256, 64), + nn.ReLU(), + nn.Linear(64, n_c) + ) + self.nblocks = nblocks + + def forward(self, x): + points, _ = self.backbone(x) + res = self.fc2(points.mean(1)) + return res + + +class PointTransformerSeg(nn.Module): + def __init__(self, cfg): + super().__init__() + self.backbone = Backbone(cfg) + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc2 = nn.Sequential( + nn.Linear(32 * 2 ** nblocks, 512), 
+ nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 32 * 2 ** nblocks) + ) + self.transformer2 = TransformerBlock(32 * 2 ** nblocks, cfg.model.transformer_dim, nneighbor) + self.nblocks = nblocks + self.transition_ups = nn.ModuleList() + self.transformers = nn.ModuleList() + for i in reversed(range(nblocks)): + channel = 32 * 2 ** i + self.transition_ups.append(TransitionUp(channel * 2, channel, channel)) + self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) + + self.fc3 = nn.Sequential( + nn.Linear(32, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, n_c) + ) + + def forward(self, x): + points, xyz_and_feats = self.backbone(x) + xyz = xyz_and_feats[-1][0] + points = self.transformer2(xyz, self.fc2(points))[0] + + for i in range(self.nblocks): + points = self.transition_ups[i](xyz, points, xyz_and_feats[- i - 2][0], xyz_and_feats[- i - 2][1]) + xyz = xyz_and_feats[- i - 2][0] + points = self.transformers[i](xyz, points)[0] + + return self.fc3(points) + + + \ No newline at end of file diff --git a/models/Hengshuang/transformer.py b/models/Hengshuang/transformer.py index 0b688c5f9af67e65c77b17cb3572284d64da901b..942fb5526a817ad1fcb0ce16e821ab13caf28318 100644 --- a/models/Hengshuang/transformer.py +++ b/models/Hengshuang/transformer.py @@ -1,45 +1,45 @@ -from pointnet_util import index_points, square_distance -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -class TransformerBlock(nn.Module): - def __init__(self, d_points, d_model, k) -> None: - super().__init__() - self.fc1 = nn.Linear(d_points, d_model) - self.fc2 = nn.Linear(d_model, d_points) - self.fc_delta = nn.Sequential( - nn.Linear(3, d_model), - nn.ReLU(), - nn.Linear(d_model, d_model) - ) - self.fc_gamma = nn.Sequential( - nn.Linear(d_model, d_model), - nn.ReLU(), - nn.Linear(d_model, d_model) - ) - self.w_qs = nn.Linear(d_model, d_model, bias=False) - self.w_ks = nn.Linear(d_model, d_model, bias=False) - self.w_vs = nn.Linear(d_model, d_model, bias=False) - self.k = k - - # xyz: b x n x 3, features: b x n x f - def forward(self, xyz, features): - dists = square_distance(xyz, xyz) - knn_idx = dists.argsort()[:, :, :self.k] # b x n x k - knn_xyz = index_points(xyz, knn_idx) - - pre = features - x = self.fc1(features) - q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx) - - pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f - - attn = self.fc_gamma(q[:, :, None] - k + pos_enc) - attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) # b x n x k x f - - res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) - res = self.fc2(res) + pre - return res, attn +from pointnet_util import index_points, square_distance +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class TransformerBlock(nn.Module): + def __init__(self, d_points, d_model, k) -> None: + super().__init__() + self.fc1 = nn.Linear(d_points, d_model) + self.fc2 = nn.Linear(d_model, d_points) + self.fc_delta = nn.Sequential( + nn.Linear(3, d_model), + nn.ReLU(), + nn.Linear(d_model, d_model) + ) + self.fc_gamma = nn.Sequential( + nn.Linear(d_model, d_model), + nn.ReLU(), + nn.Linear(d_model, d_model) + ) + self.w_qs = nn.Linear(d_model, d_model, bias=False) + self.w_ks = nn.Linear(d_model, d_model, bias=False) + self.w_vs = nn.Linear(d_model, d_model, bias=False) + self.k = k + + # xyz: b x n x 3, features: b x n x f + def forward(self, xyz, features): + dists = 
square_distance(xyz, xyz) + knn_idx = dists.argsort()[:, :, :self.k] # b x n x k + knn_xyz = index_points(xyz, knn_idx) + + pre = features + x = self.fc1(features) + q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx) + + pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f + + attn = self.fc_gamma(q[:, :, None] - k + pos_enc) + attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) # b x n x k x f + + res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) + res = self.fc2(res) + pre + return res, attn \ No newline at end of file diff --git a/models/Menghao/model.py b/models/Menghao/model.py index ed7481c44b154a7a5c504f459c2005afe7c1e637..f60f606f40722bbc63aaad67c507624fcef76226 100644 --- a/models/Menghao/model.py +++ b/models/Menghao/model.py @@ -1,159 +1,159 @@ -import torch -import torch.nn as nn -from pointnet_util import farthest_point_sample, index_points, square_distance - - -def sample_and_group(npoint, nsample, xyz, points): - B, N, C = xyz.shape - S = npoint - - fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] - - new_xyz = index_points(xyz, fps_idx) - new_points = index_points(points, fps_idx) - - dists = square_distance(new_xyz, xyz) # B x npoint x N - idx = dists.argsort()[:, :, :nsample] # B x npoint x K - - grouped_points = index_points(points, idx) - grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) - new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) - return new_xyz, new_points - - -class Local_op(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm1d(out_channels) - self.bn2 = nn.BatchNorm1d(out_channels) - self.relu = nn.ReLU() - - def forward(self, x): - b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) - x = x.permute(0, 1, 3, 2) - x = x.reshape(-1, d, s) - batch_size, _, N = x.size() - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) # B, D, N - x = torch.max(x, 2)[0] - x = x.view(batch_size, -1) - x = x.reshape(b, n, -1).permute(0, 2, 1) - return x - - -class SA_Layer(nn.Module): - def __init__(self, channels): - super().__init__() - self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) - self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) - self.q_conv.weight = self.k_conv.weight - self.v_conv = nn.Conv1d(channels, channels, 1) - self.trans_conv = nn.Conv1d(channels, channels, 1) - self.after_norm = nn.BatchNorm1d(channels) - self.act = nn.ReLU() - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x): - x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c - x_k = self.k_conv(x)# b, c, n - x_v = self.v_conv(x) - energy = x_q @ x_k # b, n, n - attention = self.softmax(energy) - attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) - x_r = x_v @ attention # b, c, n - x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) - x = x + x_r - return x - - -class StackedAttention(nn.Module): - def __init__(self, channels=256): - super().__init__() - self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) - - self.bn1 = nn.BatchNorm1d(channels) - self.bn2 = nn.BatchNorm1d(channels) - - self.sa1 = SA_Layer(channels) - self.sa2 = SA_Layer(channels) - 
self.sa3 = SA_Layer(channels) - self.sa4 = SA_Layer(channels) - - self.relu = nn.ReLU() - - def forward(self, x): - # - # b, 3, npoint, nsample - # conv2d 3 -> 128 channels 1, 1 - # b * npoint, c, nsample - # permute reshape - batch_size, _, N = x.size() - - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) - - x1 = self.sa1(x) - x2 = self.sa2(x1) - x3 = self.sa3(x2) - x4 = self.sa4(x3) - - x = torch.cat((x1, x2, x3, x4), dim=1) - - return x - - -class PointTransformer(nn.Module): - def __init__(self, cfg): - super().__init__() - output_channels = cfg.num_class - d_points = cfg.input_dim - self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm1d(64) - self.bn2 = nn.BatchNorm1d(64) - self.gather_local_0 = Local_op(in_channels=128, out_channels=128) - self.gather_local_1 = Local_op(in_channels=256, out_channels=256) - self.pt_last = StackedAttention() - - self.relu = nn.ReLU() - self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), - nn.BatchNorm1d(1024), - nn.LeakyReLU(negative_slope=0.2)) - - self.linear1 = nn.Linear(1024, 512, bias=False) - self.bn6 = nn.BatchNorm1d(512) - self.dp1 = nn.Dropout(p=0.5) - self.linear2 = nn.Linear(512, 256) - self.bn7 = nn.BatchNorm1d(256) - self.dp2 = nn.Dropout(p=0.5) - self.linear3 = nn.Linear(256, output_channels) - - def forward(self, x): - xyz = x[..., :3] - x = x.permute(0, 2, 1) - batch_size, _, _ = x.size() - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) # B, D, N - x = x.permute(0, 2, 1) - new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) - feature_0 = self.gather_local_0(new_feature) - feature = feature_0.permute(0, 2, 1) - new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) - feature_1 = self.gather_local_1(new_feature) - - x = self.pt_last(feature_1) - x = torch.cat([x, feature_1], dim=1) - x = self.conv_fuse(x) - x = torch.max(x, 2)[0] - x = x.view(batch_size, -1) - - x = self.relu(self.bn6(self.linear1(x))) - x = self.dp1(x) - x = self.relu(self.bn7(self.linear2(x))) - x = self.dp2(x) - x = self.linear3(x) - +import torch +import torch.nn as nn +from pointnet_util import farthest_point_sample, index_points, square_distance + + +def sample_and_group(npoint, nsample, xyz, points): + B, N, C = xyz.shape + S = npoint + + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] + + new_xyz = index_points(xyz, fps_idx) + new_points = index_points(points, fps_idx) + + dists = square_distance(new_xyz, xyz) # B x npoint x N + idx = dists.argsort()[:, :, :nsample] # B x npoint x K + + grouped_points = index_points(points, idx) + grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) + new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) + return new_xyz, new_points + + +class Local_op(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, d, s) + batch_size, _, N = x.size() + x = 
self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + x = x.reshape(b, n, -1).permute(0, 2, 1) + return x + + +class SA_Layer(nn.Module): + def __init__(self, channels): + super().__init__() + self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.q_conv.weight = self.k_conv.weight + self.v_conv = nn.Conv1d(channels, channels, 1) + self.trans_conv = nn.Conv1d(channels, channels, 1) + self.after_norm = nn.BatchNorm1d(channels) + self.act = nn.ReLU() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c + x_k = self.k_conv(x)# b, c, n + x_v = self.v_conv(x) + energy = x_q @ x_k # b, n, n + attention = self.softmax(energy) + attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) + x_r = x_v @ attention # b, c, n + x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) + x = x + x_r + return x + + +class StackedAttention(nn.Module): + def __init__(self, channels=256): + super().__init__() + self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + + self.bn1 = nn.BatchNorm1d(channels) + self.bn2 = nn.BatchNorm1d(channels) + + self.sa1 = SA_Layer(channels) + self.sa2 = SA_Layer(channels) + self.sa3 = SA_Layer(channels) + self.sa4 = SA_Layer(channels) + + self.relu = nn.ReLU() + + def forward(self, x): + # + # b, 3, npoint, nsample + # conv2d 3 -> 128 channels 1, 1 + # b * npoint, c, nsample + # permute reshape + batch_size, _, N = x.size() + + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) + + x1 = self.sa1(x) + x2 = self.sa2(x1) + x3 = self.sa3(x2) + x4 = self.sa4(x3) + + x = torch.cat((x1, x2, x3, x4), dim=1) + + return x + + +class PointTransformerCls(nn.Module): + def __init__(self, cfg): + super().__init__() + output_channels = cfg.num_class + d_points = cfg.input_dim + self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=128) + self.gather_local_1 = Local_op(in_channels=256, out_channels=256) + self.pt_last = StackedAttention() + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), + nn.BatchNorm1d(1024), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(1024, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(512, 256) + self.bn7 = nn.BatchNorm1d(256) + self.dp2 = nn.Dropout(p=0.5) + self.linear3 = nn.Linear(256, output_channels) + + def forward(self, x): + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) + feature_1 = self.gather_local_1(new_feature) + + x = self.pt_last(feature_1) + x = torch.cat([x, feature_1], dim=1) + x = self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = 
x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.relu(self.bn7(self.linear2(x))) + x = self.dp2(x) + x = self.linear3(x) + return x \ No newline at end of file diff --git a/models/Nico/model.py b/models/Nico/model.py index 568fa6ecefc780b9074ec8f8a8c105c332a05e8e..e291554ca55829c170bcdb37959a874bc0ef7b54 100644 --- a/models/Nico/model.py +++ b/models/Nico/model.py @@ -82,7 +82,7 @@ class GlobalFeatureGeneration(nn.Module): return out, x -class PointTransformer(nn.Module): +class PointTransformerCls(nn.Module): def __init__(self, cfg): super().__init__() d_model_l, d_model_g, d_reduce, m, k, n_c, d_points, n_head \ diff --git a/train.py b/train_cls.py similarity index 95% rename from train.py rename to train_cls.py index 2b6c1cb702fa13402465740c589872cd62d7bf36..22e820089fde78567e7b7bc91ee0e70fb5e8f2da 100644 --- a/train.py +++ b/train_cls.py @@ -41,7 +41,7 @@ def test(model, loader, num_class=40): return instance_acc, class_acc -@hydra.main(config_path='config', config_name='config') +@hydra.main(config_path='config', config_name='cls') def main(args): omegaconf.OmegaConf.set_struct(args, False) @@ -65,7 +65,7 @@ def main(args): args.input_dim = 6 if args.normal else 3 shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.') - classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformer')(args).cuda() + classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerCls')(args).cuda() criterion = torch.nn.CrossEntropyLoss() try: diff --git a/train_partseg.py b/train_partseg.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e9eaccb53823e5893e0dc2c766a466007b04cd --- /dev/null +++ b/train_partseg.py @@ -0,0 +1,242 @@ +""" +Author: Benny +Date: Nov 2019 +""" +import argparse +import os +import torch +import datetime +import logging +import sys +import importlib +import shutil +import provider +import numpy as np + +from pathlib import Path +from tqdm import tqdm +from dataset import PartNormalDataset +import hydra +import omegaconf + + +seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], + 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], + 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], + 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} +seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} +for cat in seg_classes.keys(): + for label in seg_classes[cat]: + seg_label_to_cat[label] = cat + + +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + +def to_categorical(y, num_classes): + """ 1-hot encodes a tensor """ + new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] + if (y.is_cuda): + return new_y.cuda() + return new_y + +@hydra.main(config_path='config', config_name='partseg') +def main(args): + omegaconf.OmegaConf.set_struct(args, False) + + '''HYPER PARAMETER''' + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + logger = logging.getLogger(__name__) + + print(args.pretty()) + + root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') + + TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) + trainDataLoader = 
torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
+    TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
+    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
+
+    '''MODEL LOADING'''
+    args.input_dim = (6 if args.normal else 3) + 16
+    args.num_class = 50
+    num_category = 16
+    num_part = args.num_class
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
+
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
+    criterion = torch.nn.CrossEntropyLoss()
+
+    try:
+        checkpoint = torch.load('best_model.pth')
+        start_epoch = checkpoint['epoch']
+        classifier.load_state_dict(checkpoint['model_state_dict'])
+        logger.info('Using pretrained model')
+    except Exception:
+        logger.info('No existing model, starting training from scratch...')
+        start_epoch = 0
+
+    if args.optimizer == 'Adam':
+        optimizer = torch.optim.Adam(
+            classifier.parameters(),
+            lr=args.learning_rate,
+            betas=(0.9, 0.999),
+            eps=1e-08,
+            weight_decay=args.weight_decay
+        )
+    else:
+        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
+
+    def bn_momentum_adjust(m, momentum):
+        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
+            m.momentum = momentum
+
+    LEARNING_RATE_CLIP = 1e-5
+    MOMENTUM_ORIGINAL = 0.1
+    MOMENTUM_DECAY = 0.5
+    MOMENTUM_DECAY_STEP = args.step_size
+
+    best_acc = 0
+    global_epoch = 0
+    best_class_avg_iou = 0
+    best_instance_avg_iou = 0
+
+    for epoch in range(start_epoch, args.epoch):
+        mean_correct = []
+
+        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
+        '''Adjust learning rate and BN momentum'''
+        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
+        logger.info('Learning rate:%f' % lr)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
+        if momentum < 0.01:
+            momentum = 0.01
+        print('BN momentum updated to: %f' % momentum)
+        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
+        classifier = classifier.train()
+
+        '''learning one epoch'''
+        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
+            points = points.data.numpy()
+            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
+            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
+            points = torch.Tensor(points)
+
+            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
+            optimizer.zero_grad()
+
+            seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
+            seg_pred = seg_pred.contiguous().view(-1, num_part)
+            target = target.view(-1, 1)[:, 0]
+            pred_choice = seg_pred.data.max(1)[1]
+
+            correct = pred_choice.eq(target.data).cpu().sum()
+            mean_correct.append(correct.item() / (args.batch_size * args.num_point))
+            loss = criterion(seg_pred, target)
+            loss.backward()
+            optimizer.step()
+
+        train_instance_acc = np.mean(mean_correct)
+        logger.info('Train accuracy is: %.5f' % train_instance_acc)
+
+        with torch.no_grad():
+            test_metrics = {}
+            total_correct = 0
+            total_seen = 0
+            total_seen_class = [0 for _ in range(num_part)]
+            total_correct_class = [0 for _ in range(num_part)]
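+            # Evaluation protocol (as in Pointnet_Pointnet2_pytorch): each prediction
+            # is restricted to the part labels of the shape's ground-truth category,
+            # and a part absent from both prediction and ground truth counts as IoU 1.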
+            shape_ious = {cat: [] for cat in seg_classes.keys()}
+            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
+
+            for cat in seg_classes.keys():
+                for label in seg_classes[cat]:
+                    seg_label_to_cat[label] = cat
+
+            classifier = classifier.eval()
+
+            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
+                cur_batch_size, NUM_POINT, _ = points.size()
+                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
+                seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
+                cur_pred_val = seg_pred.cpu().data.numpy()
+                cur_pred_val_logits = cur_pred_val
+                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
+                target = target.cpu().data.numpy()
+
+                for i in range(cur_batch_size):
+                    cat = seg_label_to_cat[target[i, 0]]
+                    logits = cur_pred_val_logits[i, :, :]
+                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
+
+                correct = np.sum(cur_pred_val == target)
+                total_correct += correct
+                total_seen += (cur_batch_size * NUM_POINT)
+
+                for l in range(num_part):
+                    total_seen_class[l] += np.sum(target == l)
+                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
+
+                for i in range(cur_batch_size):
+                    segp = cur_pred_val[i, :]
+                    segl = target[i, :]
+                    cat = seg_label_to_cat[segl[0]]
+                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
+                    for l in seg_classes[cat]:
+                        if (np.sum(segl == l) == 0) and (
+                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
+                            part_ious[l - seg_classes[cat][0]] = 1.0
+                        else:
+                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
+                                np.sum((segl == l) | (segp == l)))
+                    shape_ious[cat].append(np.mean(part_ious))
+
+            all_shape_ious = []
+            for cat in shape_ious.keys():
+                for iou in shape_ious[cat]:
+                    all_shape_ious.append(iou)
+                shape_ious[cat] = np.mean(shape_ious[cat])
+            mean_shape_ious = np.mean(list(shape_ious.values()))
+            test_metrics['accuracy'] = total_correct / float(total_seen)
+            test_metrics['class_avg_accuracy'] = np.mean(
+                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
+            for cat in sorted(shape_ious.keys()):
+                logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
+            test_metrics['class_avg_iou'] = mean_shape_ious
+            test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
+
+        logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Instance avg mIOU: %f' % (
+            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
+        if (test_metrics['instance_avg_iou'] >= best_instance_avg_iou):
+            logger.info('Save model...')
+            savepath = 'best_model.pth'
+            logger.info('Saving at %s' % savepath)
+            state = {
+                'epoch': epoch,
+                'train_acc': train_instance_acc,
+                'test_acc': test_metrics['accuracy'],
+                'class_avg_iou': test_metrics['class_avg_iou'],
+                'instance_avg_iou': test_metrics['instance_avg_iou'],
+                'model_state_dict': classifier.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+            }
+            torch.save(state, savepath)
+            logger.info('Saving model....')
+
+        if test_metrics['accuracy'] > best_acc:
+            best_acc = test_metrics['accuracy']
+        if test_metrics['class_avg_iou'] > best_class_avg_iou:
+            best_class_avg_iou = test_metrics['class_avg_iou']
+        if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
+            best_instance_avg_iou = test_metrics['instance_avg_iou']
+        logger.info('Best accuracy is: %.5f' % best_acc)
+        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
+        logger.info('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
+        global_epoch += 1
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
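Usage note: `train_partseg.py` feeds the network each point's xyz (plus normals when `normal: True`) concatenated with a 16-way one-hot object-category vector, so `input_dim = (6 if normal else 3) + 16` and `num_class = 50`. A minimal forward-pass sketch of the new `PointTransformerSeg` under those assumptions (config values copied from `config/partseg.yaml` and `config/model/Hengshuang.yaml`; assumes the repo's `pointnet_util` helpers are importable):
```python
import torch
from omegaconf import OmegaConf
from models.Hengshuang.model import PointTransformerSeg

# Hypothetical config mirroring config/partseg.yaml and the fields
# train_partseg.py sets in main(); model values from config/model/Hengshuang.yaml.
cfg = OmegaConf.create({
    'num_point': 1024,
    'input_dim': 6 + 16,  # xyz + normals, plus 16-way one-hot object category
    'num_class': 50,      # 50 part labels over the 16 ShapeNet-Part categories
    'model': {'name': 'Hengshuang', 'nneighbor': 16, 'nblocks': 4, 'transformer_dim': 512},
})

model = PointTransformerSeg(cfg).eval()
points = torch.randn(2, cfg.num_point, 6)    # B x N x 6 cloud with normals
label = torch.zeros(2, 1, dtype=torch.long)  # object category of each cloud
onehot = torch.nn.functional.one_hot(label, num_classes=16).float()  # B x 1 x 16
with torch.no_grad():
    seg_logits = model(torch.cat([points, onehot.repeat(1, cfg.num_point, 1)], dim=-1))
print(seg_logits.shape)  # torch.Size([2, 1024, 50])
```
At evaluation time the script restricts each point's argmax to the part labels of the shape's ground-truth category before computing IoU.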