diff --git a/.gitignore b/.gitignore
index 1c2f8207220eba333542658c0a023c2e466185c6..6c1eb2247e93d0fd6f61da12c4c9327db9c93222 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
-.vscode
-__pycache__/
-modelnet40_normal_resampled/
-outputs/
+.vscode
+__pycache__/
+modelnet40_normal_resampled/
+outputs/
 log/
\ No newline at end of file
diff --git a/README.md b/README.md
index 9bac2b6c94d887156cffb33436ff737822d6dd07..d691a91d08baa8ceea55b7bc1b427e15dca5b86b 100644
--- a/README.md
+++ b/README.md
@@ -1,26 +1,39 @@
-# Pytorch Implementation of Various Point Transformers
-
-Recently, various methods applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a pytorch implementation for these methods and aims to compare them under a fair setting. Currently, all three methods are implemented, while tuning their hyperparameters.
-
-
-## Classification
-### Data Preparation
-Download alignment **ModelNet** [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save in `modelnet40_normal_resampled`.
-
-### Run
-Change which method to use in `config/config.yaml` and run
-```
-python train.py
-```
-### Results
-Using Adam with learning rate decay 0.3 for every 50 epochs, train for 200 epochs; data augmentation follows [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). For Hengshuang and Nico, initial LR is 1e-3 (I would appreciate if someone could fine-tune these hyper-paramters); for Menghao, initial LR is 1e-4, as suggested by the [author](https://github.com/MenghaoGuo). ModelNet40 classification results (instance average) are listed below:
-| Model | Accuracy |
-|--|--|
-| Hengshuang | 89.6|
-| Menghao | 92.6 |
-| Nico | 85.5 |
-
-### Miscellaneous
-Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
-Code for [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688) is adapted from the author's Jittor implementation https://github.com/MenghaoGuo/PCT.
-
+# PyTorch Implementation of Various Point Transformers
+
+Recently, various methods have applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a PyTorch implementation of these methods and aims to compare them under a fair setting. Currently, all three methods are implemented; their hyperparameters are still being tuned.
+
+
+## Classification
+### Data Preparation
+Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `modelnet40_normal_resampled`.
+
+### Run
+Change which method to use in `config/cls.yaml` and run
+```
+python train_cls.py
+```
+### Results
+Training uses Adam for 200 epochs, with the learning rate decayed by 0.3 every 50 epochs; data augmentation follows [this repo](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). For Hengshuang and Nico, the initial LR is 1e-3 (I would appreciate it if someone could fine-tune these hyper-parameters); for Menghao, the initial LR is 1e-4, as suggested by the [author](https://github.com/MenghaoGuo).
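+A minimal sketch of this schedule with standard PyTorch pieces (placeholder `model` and hypothetical `train_one_epoch`; `train_cls.py` may implement the decay differently):
+```python
+import torch
+
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-4 for Menghao
+# decay the learning rate by 0.3 every 50 epochs, for 200 epochs total
+scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
+for epoch in range(200):
+    train_one_epoch(model, optimizer)  # hypothetical single-epoch training step
+    scheduler.step()
+```
+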
+ModelNet40 classification results (instance average) are listed below:
+| Model | Accuracy |
+|--|--|
+| Hengshuang | 89.6|
+| Menghao | 92.6 |
+| Nico | 85.5 |
+
+
+## Part Segmentation
+### Data Preparation
+Download the aligned **ShapeNet** dataset [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and save it in `data/shapenetcore_partanno_segmentation_benchmark_v0_normal`.
+
+### Run
+Change which method to use in `config/partseg.yaml` and run
+```
+python train_partseg.py
+```
+### Results
+Currently, only Hengshuang's method is implemented.
+
+### Miscellaneous
+Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
+Code for [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688) is adapted from the author's Jittor implementation https://github.com/MenghaoGuo/PCT.
+
diff --git a/config/config.yaml b/config/cls.yaml
similarity index 74%
rename from config/config.yaml
rename to config/cls.yaml
index a179f365d498171a3ef7520956b58008c57a8fad..1d7783d0063871b78c3450073297cf58732286ec 100644
--- a/config/config.yaml
+++ b/config/cls.yaml
@@ -12,8 +12,8 @@ defaults:
 
 hydra:
   run:
-    dir: log/${model.name}
+    dir: log/cls/${model.name}
 
   sweep:
-    dir: log
+    dir: log/cls
     subdir: ${model.name}
\ No newline at end of file
diff --git a/config/model/Hengshuang.yaml b/config/model/Hengshuang.yaml
index b6a9e188d1a970176b35ba7088565fc164f01908..f3cd12e5496c8499abddf41589c501b1b46e6f18 100644
--- a/config/model/Hengshuang.yaml
+++ b/config/model/Hengshuang.yaml
@@ -1,5 +1,5 @@
-# @package _group_
-nneighbor: 16
-nblocks: 4
-transformer_dim: 512
+# @package _group_
+nneighbor: 16
+nblocks: 4
+transformer_dim: 512
 name: Hengshuang
\ No newline at end of file
diff --git a/config/model/Menghao.yaml b/config/model/Menghao.yaml
index e23bed464de68138a5768abe7e1a0ffac2533708..fea1a8083f378f5d93409de6468f760c7b7c05ec 100644
--- a/config/model/Menghao.yaml
+++ b/config/model/Menghao.yaml
@@ -1,2 +1,2 @@
-# @package _group_
+# @package _group_
 name: Menghao
\ No newline at end of file
diff --git a/config/model/Nico.yaml b/config/model/Nico.yaml
index eaed132363eea5516edda5063df6eb30b107f5ea..9fd6b2e50f84eddcc244a351533f4fab8837e8b5 100644
--- a/config/model/Nico.yaml
+++ b/config/model/Nico.yaml
@@ -1,9 +1,9 @@
-# @package _group_
-n_head: 8
-m: 4
-k: 64
-global_k: 128
-global_dim: 512
-local_dim: 256
-reduce_dim: 64
+# @package _group_
+n_head: 8
+m: 4
+k: 64
+global_k: 128
+global_dim: 512
+local_dim: 256
+reduce_dim: 64
 name: Nico
\ No newline at end of file
diff --git a/config/partseg.yaml b/config/partseg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e3dd3e59155b058696b7682256a3d92b446da64
--- /dev/null
+++ b/config/partseg.yaml
@@ -0,0 +1,21 @@
+batch_size: 16
+epoch: 200
+learning_rate: 1e-3
+gpu: 1
+num_point: 1024
+optimizer: Adam
+weight_decay: 1e-4
+normal: True
+lr_decay: 0.5
+step_size: 20
+
+defaults:
+  - model: Hengshuang
+
+hydra:
+  run:
+    dir: log/partseg/${model.name}
+
+  sweep:
+    dir: log/partseg
+    subdir: ${model.name}
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
index dc56eea011c0d01782c0b1f66310aecacd91aaa7..62cff7d5dd36a502d83027a04631281ae27d63a2 100644
--- a/dataset.py
+++ b/dataset.py
@@ -3,6 +3,7 @@ import os
 from torch.utils.data import Dataset
 import torch
 from pointnet_util import farthest_point_sample, pc_normalize
+import json
 
 
 class ModelNetDataLoader(Dataset):
@@ -60,6 +61,108 @@ class ModelNetDataLoader(Dataset):
         return self._get_item(index)
 
 
+class PartNormalDataset(Dataset):
+    def __init__(self, root='./data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):
+        self.npoints = npoints
+        self.root = root
+        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
+        self.cat = {}
+        self.normal_channel = normal_channel
+
+
+        with open(self.catfile, 'r') as f:
+            for line in f:
+                ls = line.strip().split()
+                self.cat[ls[0]] = ls[1]
+        self.cat = {k: v for k, v in self.cat.items()}
+        self.classes_original = dict(zip(self.cat, range(len(self.cat))))
+
+        if class_choice is not None:
+            self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
+        # print(self.cat)
+
+        self.meta = {}
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
+            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
+            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
+            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
+        for item in self.cat:
+            # print('category', item)
+            self.meta[item] = []
+            dir_point = os.path.join(self.root, self.cat[item])
+            fns = sorted(os.listdir(dir_point))
+            # print(fns[0][0:-4])
+            if split == 'trainval':
+                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
+            elif split == 'train':
+                fns = [fn for fn in fns if fn[0:-4] in train_ids]
+            elif split == 'val':
+                fns = [fn for fn in fns if fn[0:-4] in val_ids]
+            elif split == 'test':
+                fns = [fn for fn in fns if fn[0:-4] in test_ids]
+            else:
+                print('Unknown split: %s. Exiting..'
% (split)) + exit(-1) + + # print(os.path.basename(fns)) + for fn in fns: + token = (os.path.splitext(os.path.basename(fn))[0]) + self.meta[item].append(os.path.join(dir_point, token + '.txt')) + + self.datapath = [] + for item in self.cat: + for fn in self.meta[item]: + self.datapath.append((item, fn)) + + self.classes = {} + for i in self.cat.keys(): + self.classes[i] = self.classes_original[i] + + # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels + self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], + 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], + 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], + 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], + 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} + + # for cat in sorted(self.seg_classes.keys()): + # print(cat, self.seg_classes[cat]) + + self.cache = {} # from index to (point_set, cls, seg) tuple + self.cache_size = 20000 + + + def __getitem__(self, index): + if index in self.cache: + point_set, cls, seg = self.cache[index] + else: + fn = self.datapath[index] + cat = self.datapath[index][0] + cls = self.classes[cat] + cls = np.array([cls]).astype(np.int32) + data = np.loadtxt(fn[1]).astype(np.float32) + if not self.normal_channel: + point_set = data[:, 0:3] + else: + point_set = data[:, 0:6] + seg = data[:, -1].astype(np.int32) + if len(self.cache) < self.cache_size: + self.cache[index] = (point_set, cls, seg) + point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) + + choice = np.random.choice(len(seg), self.npoints, replace=True) + # resample + point_set = point_set[choice, :] + seg = seg[choice] + + return point_set, cls, seg + + def __len__(self): + return len(self.datapath) + + if __name__ == '__main__': data = ModelNetDataLoader('modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True) DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) diff --git a/models/Hengshuang/model.py b/models/Hengshuang/model.py index 5b97e807dac412916f2959a9aabece44b2f9ce7d..161d7bc5eaf4b34fe076a06f0ee902b2dcc00487 100644 --- a/models/Hengshuang/model.py +++ b/models/Hengshuang/model.py @@ -1,49 +1,141 @@ -import torch -import torch.nn as nn -from pointnet_util import PointNetSetAbstraction -from .transformer import TransformerBlock - - -class TransitionDown(nn.Module): - def __init__(self, k, nneighbor, channels) -> None: - super().__init__() - self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) - - def forward(self, xyz, points): - return self.sa(xyz, points) - - -class PointTransformer(nn.Module): - def __init__(self, cfg) -> None: - super().__init__() - npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim - self.fc1 = nn.Sequential( - nn.Linear(d_points, 32), - nn.ReLU(), - nn.Linear(32, 32) - ) - self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) - self.transition_downs = nn.ModuleList() - self.transformers = nn.ModuleList() - for i in range(nblocks): - channel = 32 * 2 ** (i + 1) - self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) - self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) - - self.fc2 = nn.Sequential( - nn.Linear(32 * 2 ** nblocks, 256), - 
nn.ReLU(), - nn.Linear(256, 64), - nn.ReLU(), - nn.Linear(64, n_c) - ) - self.nblocks = nblocks - - def forward(self, x): - xyz = x[..., :3] - points = self.transformer1(xyz, self.fc1(x))[0] - for i in range(self.nblocks): - xyz, points = self.transition_downs[i](xyz, points) - points = self.transformers[i](xyz, points)[0] - res = self.fc2(points.mean(1)) - return res \ No newline at end of file +import torch +import torch.nn as nn +from pointnet_util import PointNetFeaturePropagation, PointNetSetAbstraction +from .transformer import TransformerBlock + + +class TransitionDown(nn.Module): + def __init__(self, k, nneighbor, channels): + super().__init__() + self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) + + def forward(self, xyz, points): + return self.sa(xyz, points) + + +class TransitionUp(nn.Module): + def __init__(self, dim1, dim2, dim_out): + class SwapAxes(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.transpose(1, 2) + + super().__init__() + self.fc1 = nn.Sequential( + nn.Linear(dim1, dim_out), + SwapAxes(), + nn.BatchNorm1d(dim_out), # TODO + SwapAxes(), + nn.ReLU(), + ) + self.fc2 = nn.Sequential( + nn.Linear(dim2, dim_out), + SwapAxes(), + nn.BatchNorm1d(dim_out), # TODO + SwapAxes(), + nn.ReLU(), + ) + self.fp = PointNetFeaturePropagation(-1, []) + + def forward(self, xyz1, points1, xyz2, points2): + feats1 = self.fc1(points1) + feats2 = self.fc2(points2) + feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2) + return feats1 + feats2 + + +class Backbone(nn.Module): + def __init__(self, cfg): + super().__init__() + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc1 = nn.Sequential( + nn.Linear(d_points, 32), + nn.ReLU(), + nn.Linear(32, 32) + ) + self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) + self.transition_downs = nn.ModuleList() + self.transformers = nn.ModuleList() + for i in range(nblocks): + channel = 32 * 2 ** (i + 1) + self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) + self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) + self.nblocks = nblocks + + def forward(self, x): + xyz = x[..., :3] + points = self.transformer1(xyz, self.fc1(x))[0] + + xyz_and_feats = [(xyz, points)] + for i in range(self.nblocks): + xyz, points = self.transition_downs[i](xyz, points) + points = self.transformers[i](xyz, points)[0] + xyz_and_feats.append((xyz, points)) + return points, xyz_and_feats + + +class PointTransformerCls(nn.Module): + def __init__(self, cfg): + super().__init__() + self.backbone = Backbone(cfg) + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc2 = nn.Sequential( + nn.Linear(32 * 2 ** nblocks, 256), + nn.ReLU(), + nn.Linear(256, 64), + nn.ReLU(), + nn.Linear(64, n_c) + ) + self.nblocks = nblocks + + def forward(self, x): + points, _ = self.backbone(x) + res = self.fc2(points.mean(1)) + return res + + +class PointTransformerSeg(nn.Module): + def __init__(self, cfg): + super().__init__() + self.backbone = Backbone(cfg) + npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim + self.fc2 = nn.Sequential( + nn.Linear(32 * 2 ** nblocks, 512), 
+ nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 32 * 2 ** nblocks) + ) + self.transformer2 = TransformerBlock(32 * 2 ** nblocks, cfg.model.transformer_dim, nneighbor) + self.nblocks = nblocks + self.transition_ups = nn.ModuleList() + self.transformers = nn.ModuleList() + for i in reversed(range(nblocks)): + channel = 32 * 2 ** i + self.transition_ups.append(TransitionUp(channel * 2, channel, channel)) + self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) + + self.fc3 = nn.Sequential( + nn.Linear(32, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, n_c) + ) + + def forward(self, x): + points, xyz_and_feats = self.backbone(x) + xyz = xyz_and_feats[-1][0] + points = self.transformer2(xyz, self.fc2(points))[0] + + for i in range(self.nblocks): + points = self.transition_ups[i](xyz, points, xyz_and_feats[- i - 2][0], xyz_and_feats[- i - 2][1]) + xyz = xyz_and_feats[- i - 2][0] + points = self.transformers[i](xyz, points)[0] + + return self.fc3(points) + + + \ No newline at end of file diff --git a/models/Hengshuang/transformer.py b/models/Hengshuang/transformer.py index 0b688c5f9af67e65c77b17cb3572284d64da901b..942fb5526a817ad1fcb0ce16e821ab13caf28318 100644 --- a/models/Hengshuang/transformer.py +++ b/models/Hengshuang/transformer.py @@ -1,45 +1,45 @@ -from pointnet_util import index_points, square_distance -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - -class TransformerBlock(nn.Module): - def __init__(self, d_points, d_model, k) -> None: - super().__init__() - self.fc1 = nn.Linear(d_points, d_model) - self.fc2 = nn.Linear(d_model, d_points) - self.fc_delta = nn.Sequential( - nn.Linear(3, d_model), - nn.ReLU(), - nn.Linear(d_model, d_model) - ) - self.fc_gamma = nn.Sequential( - nn.Linear(d_model, d_model), - nn.ReLU(), - nn.Linear(d_model, d_model) - ) - self.w_qs = nn.Linear(d_model, d_model, bias=False) - self.w_ks = nn.Linear(d_model, d_model, bias=False) - self.w_vs = nn.Linear(d_model, d_model, bias=False) - self.k = k - - # xyz: b x n x 3, features: b x n x f - def forward(self, xyz, features): - dists = square_distance(xyz, xyz) - knn_idx = dists.argsort()[:, :, :self.k] # b x n x k - knn_xyz = index_points(xyz, knn_idx) - - pre = features - x = self.fc1(features) - q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx) - - pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f - - attn = self.fc_gamma(q[:, :, None] - k + pos_enc) - attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) # b x n x k x f - - res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) - res = self.fc2(res) + pre - return res, attn +from pointnet_util import index_points, square_distance +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class TransformerBlock(nn.Module): + def __init__(self, d_points, d_model, k) -> None: + super().__init__() + self.fc1 = nn.Linear(d_points, d_model) + self.fc2 = nn.Linear(d_model, d_points) + self.fc_delta = nn.Sequential( + nn.Linear(3, d_model), + nn.ReLU(), + nn.Linear(d_model, d_model) + ) + self.fc_gamma = nn.Sequential( + nn.Linear(d_model, d_model), + nn.ReLU(), + nn.Linear(d_model, d_model) + ) + self.w_qs = nn.Linear(d_model, d_model, bias=False) + self.w_ks = nn.Linear(d_model, d_model, bias=False) + self.w_vs = nn.Linear(d_model, d_model, bias=False) + self.k = k + + # xyz: b x n x 3, features: b x n x f + def forward(self, xyz, features): + dists = 
square_distance(xyz, xyz) + knn_idx = dists.argsort()[:, :, :self.k] # b x n x k + knn_xyz = index_points(xyz, knn_idx) + + pre = features + x = self.fc1(features) + q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx) + + pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz) # b x n x k x f + + attn = self.fc_gamma(q[:, :, None] - k + pos_enc) + attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2) # b x n x k x f + + res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc) + res = self.fc2(res) + pre + return res, attn \ No newline at end of file diff --git a/models/Menghao/model.py b/models/Menghao/model.py index ed7481c44b154a7a5c504f459c2005afe7c1e637..f60f606f40722bbc63aaad67c507624fcef76226 100644 --- a/models/Menghao/model.py +++ b/models/Menghao/model.py @@ -1,159 +1,159 @@ -import torch -import torch.nn as nn -from pointnet_util import farthest_point_sample, index_points, square_distance - - -def sample_and_group(npoint, nsample, xyz, points): - B, N, C = xyz.shape - S = npoint - - fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] - - new_xyz = index_points(xyz, fps_idx) - new_points = index_points(points, fps_idx) - - dists = square_distance(new_xyz, xyz) # B x npoint x N - idx = dists.argsort()[:, :, :nsample] # B x npoint x K - - grouped_points = index_points(points, idx) - grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) - new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) - return new_xyz, new_points - - -class Local_op(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm1d(out_channels) - self.bn2 = nn.BatchNorm1d(out_channels) - self.relu = nn.ReLU() - - def forward(self, x): - b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) - x = x.permute(0, 1, 3, 2) - x = x.reshape(-1, d, s) - batch_size, _, N = x.size() - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) # B, D, N - x = torch.max(x, 2)[0] - x = x.view(batch_size, -1) - x = x.reshape(b, n, -1).permute(0, 2, 1) - return x - - -class SA_Layer(nn.Module): - def __init__(self, channels): - super().__init__() - self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) - self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) - self.q_conv.weight = self.k_conv.weight - self.v_conv = nn.Conv1d(channels, channels, 1) - self.trans_conv = nn.Conv1d(channels, channels, 1) - self.after_norm = nn.BatchNorm1d(channels) - self.act = nn.ReLU() - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x): - x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c - x_k = self.k_conv(x)# b, c, n - x_v = self.v_conv(x) - energy = x_q @ x_k # b, n, n - attention = self.softmax(energy) - attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) - x_r = x_v @ attention # b, c, n - x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) - x = x + x_r - return x - - -class StackedAttention(nn.Module): - def __init__(self, channels=256): - super().__init__() - self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) - - self.bn1 = nn.BatchNorm1d(channels) - self.bn2 = nn.BatchNorm1d(channels) - - self.sa1 = SA_Layer(channels) - self.sa2 = SA_Layer(channels) - 
self.sa3 = SA_Layer(channels) - self.sa4 = SA_Layer(channels) - - self.relu = nn.ReLU() - - def forward(self, x): - # - # b, 3, npoint, nsample - # conv2d 3 -> 128 channels 1, 1 - # b * npoint, c, nsample - # permute reshape - batch_size, _, N = x.size() - - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) - - x1 = self.sa1(x) - x2 = self.sa2(x1) - x3 = self.sa3(x2) - x4 = self.sa4(x3) - - x = torch.cat((x1, x2, x3, x4), dim=1) - - return x - - -class PointTransformer(nn.Module): - def __init__(self, cfg): - super().__init__() - output_channels = cfg.num_class - d_points = cfg.input_dim - self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False) - self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm1d(64) - self.bn2 = nn.BatchNorm1d(64) - self.gather_local_0 = Local_op(in_channels=128, out_channels=128) - self.gather_local_1 = Local_op(in_channels=256, out_channels=256) - self.pt_last = StackedAttention() - - self.relu = nn.ReLU() - self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), - nn.BatchNorm1d(1024), - nn.LeakyReLU(negative_slope=0.2)) - - self.linear1 = nn.Linear(1024, 512, bias=False) - self.bn6 = nn.BatchNorm1d(512) - self.dp1 = nn.Dropout(p=0.5) - self.linear2 = nn.Linear(512, 256) - self.bn7 = nn.BatchNorm1d(256) - self.dp2 = nn.Dropout(p=0.5) - self.linear3 = nn.Linear(256, output_channels) - - def forward(self, x): - xyz = x[..., :3] - x = x.permute(0, 2, 1) - batch_size, _, _ = x.size() - x = self.relu(self.bn1(self.conv1(x))) # B, D, N - x = self.relu(self.bn2(self.conv2(x))) # B, D, N - x = x.permute(0, 2, 1) - new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) - feature_0 = self.gather_local_0(new_feature) - feature = feature_0.permute(0, 2, 1) - new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) - feature_1 = self.gather_local_1(new_feature) - - x = self.pt_last(feature_1) - x = torch.cat([x, feature_1], dim=1) - x = self.conv_fuse(x) - x = torch.max(x, 2)[0] - x = x.view(batch_size, -1) - - x = self.relu(self.bn6(self.linear1(x))) - x = self.dp1(x) - x = self.relu(self.bn7(self.linear2(x))) - x = self.dp2(x) - x = self.linear3(x) - +import torch +import torch.nn as nn +from pointnet_util import farthest_point_sample, index_points, square_distance + + +def sample_and_group(npoint, nsample, xyz, points): + B, N, C = xyz.shape + S = npoint + + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] + + new_xyz = index_points(xyz, fps_idx) + new_points = index_points(points, fps_idx) + + dists = square_distance(new_xyz, xyz) # B x npoint x N + idx = dists.argsort()[:, :, :nsample] # B x npoint x K + + grouped_points = index_points(points, idx) + grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) + new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) + return new_xyz, new_points + + +class Local_op(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(out_channels) + self.bn2 = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) + x = x.permute(0, 1, 3, 2) + x = x.reshape(-1, d, s) + batch_size, _, N = x.size() + x = 
self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = torch.max(x, 2)[0] + x = x.view(batch_size, -1) + x = x.reshape(b, n, -1).permute(0, 2, 1) + return x + + +class SA_Layer(nn.Module): + def __init__(self, channels): + super().__init__() + self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) + self.q_conv.weight = self.k_conv.weight + self.v_conv = nn.Conv1d(channels, channels, 1) + self.trans_conv = nn.Conv1d(channels, channels, 1) + self.after_norm = nn.BatchNorm1d(channels) + self.act = nn.ReLU() + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c + x_k = self.k_conv(x)# b, c, n + x_v = self.v_conv(x) + energy = x_q @ x_k # b, n, n + attention = self.softmax(energy) + attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) + x_r = x_v @ attention # b, c, n + x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) + x = x + x_r + return x + + +class StackedAttention(nn.Module): + def __init__(self, channels=256): + super().__init__() + self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) + + self.bn1 = nn.BatchNorm1d(channels) + self.bn2 = nn.BatchNorm1d(channels) + + self.sa1 = SA_Layer(channels) + self.sa2 = SA_Layer(channels) + self.sa3 = SA_Layer(channels) + self.sa4 = SA_Layer(channels) + + self.relu = nn.ReLU() + + def forward(self, x): + # + # b, 3, npoint, nsample + # conv2d 3 -> 128 channels 1, 1 + # b * npoint, c, nsample + # permute reshape + batch_size, _, N = x.size() + + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) + + x1 = self.sa1(x) + x2 = self.sa2(x1) + x3 = self.sa3(x2) + x4 = self.sa4(x3) + + x = torch.cat((x1, x2, x3, x4), dim=1) + + return x + + +class PointTransformerCls(nn.Module): + def __init__(self, cfg): + super().__init__() + output_channels = cfg.num_class + d_points = cfg.input_dim + self.conv1 = nn.Conv1d(d_points, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(64) + self.bn2 = nn.BatchNorm1d(64) + self.gather_local_0 = Local_op(in_channels=128, out_channels=128) + self.gather_local_1 = Local_op(in_channels=256, out_channels=256) + self.pt_last = StackedAttention() + + self.relu = nn.ReLU() + self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), + nn.BatchNorm1d(1024), + nn.LeakyReLU(negative_slope=0.2)) + + self.linear1 = nn.Linear(1024, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=0.5) + self.linear2 = nn.Linear(512, 256) + self.bn7 = nn.BatchNorm1d(256) + self.dp2 = nn.Dropout(p=0.5) + self.linear3 = nn.Linear(256, output_channels) + + def forward(self, x): + xyz = x[..., :3] + x = x.permute(0, 2, 1) + batch_size, _, _ = x.size() + x = self.relu(self.bn1(self.conv1(x))) # B, D, N + x = self.relu(self.bn2(self.conv2(x))) # B, D, N + x = x.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) + feature_0 = self.gather_local_0(new_feature) + feature = feature_0.permute(0, 2, 1) + new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) + feature_1 = self.gather_local_1(new_feature) + + x = self.pt_last(feature_1) + x = torch.cat([x, feature_1], dim=1) + x = self.conv_fuse(x) + x = torch.max(x, 2)[0] + x = 
x.view(batch_size, -1) + + x = self.relu(self.bn6(self.linear1(x))) + x = self.dp1(x) + x = self.relu(self.bn7(self.linear2(x))) + x = self.dp2(x) + x = self.linear3(x) + return x \ No newline at end of file diff --git a/models/Nico/model.py b/models/Nico/model.py index 568fa6ecefc780b9074ec8f8a8c105c332a05e8e..e291554ca55829c170bcdb37959a874bc0ef7b54 100644 --- a/models/Nico/model.py +++ b/models/Nico/model.py @@ -82,7 +82,7 @@ class GlobalFeatureGeneration(nn.Module): return out, x -class PointTransformer(nn.Module): +class PointTransformerCls(nn.Module): def __init__(self, cfg): super().__init__() d_model_l, d_model_g, d_reduce, m, k, n_c, d_points, n_head \ diff --git a/train.py b/train_cls.py similarity index 95% rename from train.py rename to train_cls.py index 2b6c1cb702fa13402465740c589872cd62d7bf36..22e820089fde78567e7b7bc91ee0e70fb5e8f2da 100644 --- a/train.py +++ b/train_cls.py @@ -41,7 +41,7 @@ def test(model, loader, num_class=40): return instance_acc, class_acc -@hydra.main(config_path='config', config_name='config') +@hydra.main(config_path='config', config_name='cls') def main(args): omegaconf.OmegaConf.set_struct(args, False) @@ -65,7 +65,7 @@ def main(args): args.input_dim = 6 if args.normal else 3 shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.') - classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformer')(args).cuda() + classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerCls')(args).cuda() criterion = torch.nn.CrossEntropyLoss() try: diff --git a/train_partseg.py b/train_partseg.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e9eaccb53823e5893e0dc2c766a466007b04cd --- /dev/null +++ b/train_partseg.py @@ -0,0 +1,242 @@ +""" +Author: Benny +Date: Nov 2019 +""" +import argparse +import os +import torch +import datetime +import logging +import sys +import importlib +import shutil +import provider +import numpy as np + +from pathlib import Path +from tqdm import tqdm +from dataset import PartNormalDataset +import hydra +import omegaconf + + +seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], + 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], + 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], + 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} +seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} +for cat in seg_classes.keys(): + for label in seg_classes[cat]: + seg_label_to_cat[label] = cat + + +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + +def to_categorical(y, num_classes): + """ 1-hot encodes a tensor """ + new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] + if (y.is_cuda): + return new_y.cuda() + return new_y + +@hydra.main(config_path='config', config_name='partseg') +def main(args): + omegaconf.OmegaConf.set_struct(args, False) + + '''HYPER PARAMETER''' + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + logger = logging.getLogger(__name__) + + print(args.pretty()) + + root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') + + TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) + trainDataLoader = 
torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
+    TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
+    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
+
+    '''MODEL LOADING'''
+    args.input_dim = (6 if args.normal else 3) + 16
+    args.num_class = 50
+    num_category = 16
+    num_part = args.num_class
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
+
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
+    criterion = torch.nn.CrossEntropyLoss()
+
+    try:
+        checkpoint = torch.load('best_model.pth')
+        start_epoch = checkpoint['epoch']
+        classifier.load_state_dict(checkpoint['model_state_dict'])
+        logger.info('Using pretrained model')
+    except Exception:
+        logger.info('No existing model, starting training from scratch...')
+        start_epoch = 0
+
+    if args.optimizer == 'Adam':
+        optimizer = torch.optim.Adam(
+            classifier.parameters(),
+            lr=args.learning_rate,
+            betas=(0.9, 0.999),
+            eps=1e-08,
+            weight_decay=args.weight_decay
+        )
+    else:
+        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
+
+    def bn_momentum_adjust(m, momentum):
+        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
+            m.momentum = momentum
+
+    LEARNING_RATE_CLIP = 1e-5
+    MOMENTUM_ORIGINAL = 0.1
+    MOMENTUM_DECAY = 0.5
+    MOMENTUM_DECAY_STEP = args.step_size
+
+    best_acc = 0
+    global_epoch = 0
+    best_class_avg_iou = 0
+    best_instance_avg_iou = 0
+
+    for epoch in range(start_epoch, args.epoch):
+        mean_correct = []
+
+        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
+        '''Adjust learning rate and BN momentum'''
+        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
+        logger.info('Learning rate:%f' % lr)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
+        if momentum < 0.01:
+            momentum = 0.01
+        print('BN momentum updated to: %f' % momentum)
+        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
+        classifier = classifier.train()
+
+        '''learning one epoch'''
+        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
+            points = points.data.numpy()
+            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
+            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
+            points = torch.Tensor(points)
+
+            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
+            optimizer.zero_grad()
+
+            seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
+            seg_pred = seg_pred.contiguous().view(-1, num_part)
+            target = target.view(-1, 1)[:, 0]
+            pred_choice = seg_pred.data.max(1)[1]
+
+            correct = pred_choice.eq(target.data).cpu().sum()
+            mean_correct.append(correct.item() / (args.batch_size * args.num_point))
+            loss = criterion(seg_pred, target)
+            loss.backward()
+            optimizer.step()
+
+        train_instance_acc = np.mean(mean_correct)
+        logger.info('Train accuracy is: %.5f' % train_instance_acc)
+
+        with torch.no_grad():
+            test_metrics = {}
+            total_correct = 0
+            total_seen = 0
+            total_seen_class = [0 for _ in range(num_part)]
+            total_correct_class = [0 for _ in range(num_part)]
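+            # Evaluation protocol (as in Pointnet_Pointnet2_pytorch): each prediction
+            # is restricted to the part labels of the shape's ground-truth category,
+            # and a part absent from both prediction and ground truth counts as IoU 1.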
+            shape_ious = {cat: [] for cat in seg_classes.keys()}
+            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
+
+            for cat in seg_classes.keys():
+                for label in seg_classes[cat]:
+                    seg_label_to_cat[label] = cat
+
+            classifier = classifier.eval()
+
+            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
+                cur_batch_size, NUM_POINT, _ = points.size()
+                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
+                seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
+                cur_pred_val = seg_pred.cpu().data.numpy()
+                cur_pred_val_logits = cur_pred_val
+                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
+                target = target.cpu().data.numpy()
+
+                for i in range(cur_batch_size):
+                    cat = seg_label_to_cat[target[i, 0]]
+                    logits = cur_pred_val_logits[i, :, :]
+                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
+
+                correct = np.sum(cur_pred_val == target)
+                total_correct += correct
+                total_seen += (cur_batch_size * NUM_POINT)
+
+                for l in range(num_part):
+                    total_seen_class[l] += np.sum(target == l)
+                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
+
+                for i in range(cur_batch_size):
+                    segp = cur_pred_val[i, :]
+                    segl = target[i, :]
+                    cat = seg_label_to_cat[segl[0]]
+                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
+                    for l in seg_classes[cat]:
+                        if (np.sum(segl == l) == 0) and (
+                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
+                            part_ious[l - seg_classes[cat][0]] = 1.0
+                        else:
+                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
+                                np.sum((segl == l) | (segp == l)))
+                    shape_ious[cat].append(np.mean(part_ious))
+
+            all_shape_ious = []
+            for cat in shape_ious.keys():
+                for iou in shape_ious[cat]:
+                    all_shape_ious.append(iou)
+                shape_ious[cat] = np.mean(shape_ious[cat])
+            mean_shape_ious = np.mean(list(shape_ious.values()))
+            test_metrics['accuracy'] = total_correct / float(total_seen)
+            test_metrics['class_avg_accuracy'] = np.mean(
+                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
+            for cat in sorted(shape_ious.keys()):
+                logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
+            test_metrics['class_avg_iou'] = mean_shape_ious
+            test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
+
+        logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Instance avg mIOU: %f' % (
+            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))
+        if (test_metrics['instance_avg_iou'] >= best_instance_avg_iou):
+            logger.info('Save model...')
+            savepath = 'best_model.pth'
+            logger.info('Saving at %s' % savepath)
+            state = {
+                'epoch': epoch,
+                'train_acc': train_instance_acc,
+                'test_acc': test_metrics['accuracy'],
+                'class_avg_iou': test_metrics['class_avg_iou'],
+                'instance_avg_iou': test_metrics['instance_avg_iou'],
+                'model_state_dict': classifier.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+            }
+            torch.save(state, savepath)
+            logger.info('Saving model....')
+
+        if test_metrics['accuracy'] > best_acc:
+            best_acc = test_metrics['accuracy']
+        if test_metrics['class_avg_iou'] > best_class_avg_iou:
+            best_class_avg_iou = test_metrics['class_avg_iou']
+        if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
+            best_instance_avg_iou = test_metrics['instance_avg_iou']
+        logger.info('Best accuracy is: %.5f' % best_acc)
+        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
+        logger.info('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)
+        global_epoch += 1
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
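Usage note: `train_partseg.py` feeds the network each point's xyz (plus normals when `normal: True`) concatenated with a 16-way one-hot object-category vector, so `input_dim = (6 if normal else 3) + 16` and `num_class = 50`. A minimal forward-pass sketch of the new `PointTransformerSeg` under those assumptions (config values copied from `config/partseg.yaml` and `config/model/Hengshuang.yaml`; assumes the repo's `pointnet_util` helpers are importable):
```python
import torch
from omegaconf import OmegaConf
from models.Hengshuang.model import PointTransformerSeg

# Hypothetical config mirroring config/partseg.yaml and the fields
# train_partseg.py sets in main(); model values from config/model/Hengshuang.yaml.
cfg = OmegaConf.create({
    'num_point': 1024,
    'input_dim': 6 + 16,  # xyz + normals, plus 16-way one-hot object category
    'num_class': 50,      # 50 part labels over the 16 ShapeNet-Part categories
    'model': {'name': 'Hengshuang', 'nneighbor': 16, 'nblocks': 4, 'transformer_dim': 512},
})

model = PointTransformerSeg(cfg).eval()
points = torch.randn(2, cfg.num_point, 6)    # B x N x 6 cloud with normals
label = torch.zeros(2, 1, dtype=torch.long)  # object category of each cloud
onehot = torch.nn.functional.one_hot(label, num_classes=16).float()  # B x 1 x 16
with torch.no_grad():
    seg_logits = model(torch.cat([points, onehot.repeat(1, cfg.num_point, 1)], dim=-1))
print(seg_logits.shape)  # torch.Size([2, 1024, 50])
```
At evaluation time the script restricts each point's argmax to the part labels of the shape's ground-truth category before computing IoU.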