diff --git a/README.md b/README.md
index e39eadc67890857f31b24547a9d23bbcd669cc2f..69af2963ad4866d999e9a95330413a81616d8169 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Pytorch Implementation of Various Point Transformers
 
-Recently, various methods applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a pytorch implementation for these methods and aims to compare them under a fair setting. Currently, Point Transformer (Nico Engel et al.) is implemented.
+Recently, various methods have applied transformers to point clouds: [PCT: Point Cloud Transformer (Meng-Hao Guo et al.)](https://arxiv.org/abs/2012.09688), [Point Transformer (Nico Engel et al.)](https://arxiv.org/abs/2011.00931), and [Point Transformer (Hengshuang Zhao et al.)](https://arxiv.org/abs/2012.09164). This repo is a PyTorch implementation of these methods and aims to compare them under a fair setting. Currently, Point Transformer (Nico Engel et al.) and Point Transformer (Hengshuang Zhao et al.) are implemented.
 
 
 ## Classification
@@ -8,6 +8,12 @@ Recently, various methods applied transformers to point clouds: [PCT: Point Clou
-Download alignment **ModelNet** [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save in `modelnet40_normal_resampled`.
+Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `modelnet40_normal_resampled`.
 
 ### Run
+Select the model to use under `defaults` in `config/config.yaml` and run
 ```
 python train.py
 ```
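+
+Alternatively, since `model` is a Hydra config group, it can be selected from the command line, e.g.
+```
+python train.py model=Hengshuang
+```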
@@ -17,5 +23,4 @@ TBA
 ### Miscellaneous
 Some code and training settings are borrowed from https://github.com/yanx27/Pointnet_Pointnet2_pytorch.
 ## TODOs
-- [ ] implement Point Transformer (Hengshuang Zhao et al.)
 - [ ] implement PCT: Point Cloud Transformer (Meng-Hao Guo et al.)
\ No newline at end of file
diff --git a/config.yaml b/config/config.yaml
similarity index 59%
rename from config.yaml
rename to config/config.yaml
index f6e59a8aa7832d6fc9726e1e46701241473701c4..2c933de487aa4d434add44c92dce10e4c2eafe62 100644
--- a/config.yaml
+++ b/config/config.yaml
@@ -1,4 +1,4 @@
-batch_size: 24
+batch_size: 16
 epoch: 200
 learning_rate: 1e-3
 gpu: 1
@@ -7,6 +7,10 @@ optimizer: Adam
 weight_decay: 1e-4
 normal: True
 
+defaults:
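+  # selects config/model/<name>.yaml (Nico or Hengshuang)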
+  - model: Nico
+
 hydra:
   run:
-    dir: outputs
\ No newline at end of file
+    dir: outputs/${model.name}
\ No newline at end of file
diff --git a/config/model/Hengshuang.yaml b/config/model/Hengshuang.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6a9e188d1a970176b35ba7088565fc164f01908
--- /dev/null
+++ b/config/model/Hengshuang.yaml
@@ -0,0 +1,5 @@
+# @package _group_
+nneighbor: 16
+nblocks: 4
+transformer_dim: 512
+name: Hengshuang
\ No newline at end of file
diff --git a/config/model/Nico.yaml b/config/model/Nico.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaed132363eea5516edda5063df6eb30b107f5ea
--- /dev/null
+++ b/config/model/Nico.yaml
@@ -0,0 +1,9 @@
+# @package _group_
+n_head: 8
+m: 4
+k: 64
+global_k: 128
+global_dim: 512
+local_dim: 256
+reduce_dim: 64
+name: Nico
\ No newline at end of file
diff --git a/models/Hengshuang/model.py b/models/Hengshuang/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b97e807dac412916f2959a9aabece44b2f9ce7d
--- /dev/null
+++ b/models/Hengshuang/model.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from pointnet_util import PointNetSetAbstraction
+from .transformer import TransformerBlock
+
+
+class TransitionDown(nn.Module):
+    def __init__(self, k, nneighbor, channels) -> None:
+        super().__init__()
+        self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True)
+        
+    def forward(self, xyz, points):
+        return self.sa(xyz, points)
+    
+    
+class PointTransformer(nn.Module):
+    def __init__(self, cfg) -> None:
+        super().__init__()
+        npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim
+        self.fc1 = nn.Sequential(
+            nn.Linear(d_points, 32),
+            nn.ReLU(),
+            nn.Linear(32, 32)
+        )
+        self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor)
+        self.transition_downs = nn.ModuleList()
+        self.transformers = nn.ModuleList()
+        for i in range(nblocks):
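+            # stage i: 4x fewer points (npoints / 4^(i+1)) and 2x wider features (32 * 2^(i+1))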
+            channel = 32 * 2 ** (i + 1)
+            self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel]))
+            self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor))
+            
+        self.fc2 = nn.Sequential(
+            nn.Linear(32 * 2 ** nblocks, 256),
+            nn.ReLU(),
+            nn.Linear(256, 64),
+            nn.ReLU(),
+            nn.Linear(64, n_c)
+        )
+        self.nblocks = nblocks
+    
+    def forward(self, x):
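+        # x: [B, N, d_points]; the first three channels are the xyz coordinates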
+        xyz = x[..., :3]
+        points = self.transformer1(xyz, self.fc1(x))[0]
+        for i in range(self.nblocks):
+            xyz, points = self.transition_downs[i](xyz, points)
+            points = self.transformers[i](xyz, points)[0]
+        res = self.fc2(points.mean(1))
+        return res
\ No newline at end of file
diff --git a/models/Hengshuang/transformer.py b/models/Hengshuang/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b688c5f9af67e65c77b17cb3572284d64da901b
--- /dev/null
+++ b/models/Hengshuang/transformer.py
@@ -0,0 +1,46 @@
+from pointnet_util import index_points, square_distance
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class TransformerBlock(nn.Module):
+    def __init__(self, d_points, d_model, k) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(d_points, d_model)
+        self.fc2 = nn.Linear(d_model, d_points)
+        self.fc_delta = nn.Sequential(
+            nn.Linear(3, d_model),
+            nn.ReLU(),
+            nn.Linear(d_model, d_model)
+        )
+        self.fc_gamma = nn.Sequential(
+            nn.Linear(d_model, d_model),
+            nn.ReLU(),
+            nn.Linear(d_model, d_model)
+        )
+        self.w_qs = nn.Linear(d_model, d_model, bias=False)
+        self.w_ks = nn.Linear(d_model, d_model, bias=False)
+        self.w_vs = nn.Linear(d_model, d_model, bias=False)
+        self.k = k
+        
+    # xyz: b x n x 3, features: b x n x f
+    def forward(self, xyz, features):
+        dists = square_distance(xyz, xyz)
+        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
+        knn_xyz = index_points(xyz, knn_idx)
+        
+        pre = features
+        x = self.fc1(features)
+        q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx)
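+        # q: b x n x f; k and v are gathered per neighbor: b x n x k x f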
+
+        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x f
+        
+        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
+        attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2)  # b x n x k x f
+        
+        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
+        res = self.fc2(res) + pre
+        return res, attn
+    
\ No newline at end of file
diff --git a/model.py b/models/Nico/model.py
similarity index 60%
rename from model.py
rename to models/Nico/model.py
index 5ec388f5f50169ab5375accddd39541144828834..568fa6ecefc780b9074ec8f8a8c105c332a05e8e 100644
--- a/model.py
+++ b/models/Nico/model.py
@@ -1,22 +1,35 @@
 import torch
 import torch.nn as nn
 from pointnet_util import PointNetSetAbstractionMsg
-from transformer import MultiHeadAttention
+from .transformer import MultiHeadAttention
 
 
 class SortNet(nn.Module):
     def __init__(self, d_model, d_points=6, k=64):
         super().__init__()
-        self.fc = nn.Linear(d_model, 1)
+        self.fc = nn.Sequential(
+            nn.Linear(d_model, 256),
+            nn.ReLU(),
+            nn.Linear(256, 64),
+            nn.ReLU(),
+            nn.Linear(64, 1)
+        )
         self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
-        self.fc_agg = nn.Linear(64 + 128 + 128, d_model - 1 - d_points)
+        self.fc_agg = nn.Sequential(
+            nn.Linear(64 + 128 + 128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model - 1 - d_points),
+        )
         self.k = k
         self.d_points = d_points
         
     def forward(self, points, features):
         score = self.fc(features)
         topk_idx = torch.topk(score[..., 0], self.k, 1)[1]
-        features_abs = self.sa(points[..., :3].transpose(1, 2), features.transpose(1, 2), topk_idx)[1].transpose(1, 2)
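+        # keep the k highest-scoring points; gather their aggregated features, scores, and coordinates below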
+        features_abs = self.sa(points[..., :3], features, topk_idx)[1]
         res = torch.cat((self.fc_agg(features_abs),
                          torch.gather(score, 1, topk_idx[..., None].expand(-1, -1, score.size(-1))),
                          torch.gather(points, 1, topk_idx[..., None].expand(-1, -1, points.size(-1)))), -1)
@@ -26,7 +39,14 @@ class SortNet(nn.Module):
 class LocalFeatureGeneration(nn.Module):
     def __init__(self, d_model, m, k, d_points=6, n_head=4):
         super().__init__()
-        self.fc = nn.Linear(d_points, d_model)
+        self.fc = nn.Sequential(
+            nn.Linear(d_points, 64),
+            nn.ReLU(),
+            nn.Linear(64, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model)
+        ) 
-        self.sortnets = nn.ModuleList([SortNet(d_model, k=k)] * m)
+        # m independent SortNets; [SortNet(...)] * m would register the same module (and weights) m times
+        self.sortnets = nn.ModuleList([SortNet(d_model, k=k) for _ in range(m)])
         self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head)
         
@@ -40,25 +60,45 @@
 class GlobalFeatureGeneration(nn.Module):
     def __init__(self, d_model, k, d_points=6, n_head=4):
         super().__init__()
-        self.fc = nn.Linear(d_points, d_model)
+        self.fc = nn.Sequential(
+            nn.Linear(d_points, 64),
+            nn.ReLU(),
+            nn.Linear(64, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model)
+        ) 
         self.sa = PointNetSetAbstractionMsg(k, [0.1, 0.2, 0.4], [16, 32, 128], d_model, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
         self.att = MultiHeadAttention(n_head, d_model, d_model, d_model // n_head, d_model // n_head)
-        self.fc_agg = nn.Linear(64 + 128 + 128, d_model)
+        self.fc_agg = nn.Sequential(
+            nn.Linear(64 + 128 + 128, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_model),
+        )
         
     def forward(self, points):
         x = self.fc(points)
         x, _ = self.att(x, x, x)
-        out = self.fc_agg(self.sa(points[..., :3].transpose(1, 2), x.transpose(1, 2))[1].transpose(1, 2))
+        out = self.fc_agg(self.sa(points[..., :3], x)[1])
         return out, x
     
     
 class PointTransformer(nn.Module):
-    def __init__(self, d_model_l, d_model_g, d_reduce, m, k, n_c, d_points=6, n_head=4):
+    def __init__(self, cfg):
         super().__init__()
+        d_model_l, d_model_g, d_reduce, m, k, n_c, d_points, n_head \
+            = cfg.model.global_dim, cfg.model.local_dim, cfg.model.reduce_dim, cfg.model.m, cfg.model.k, cfg.num_class, cfg.input_dim, cfg.model.n_head
         self.lfg = LocalFeatureGeneration(d_model=d_model_l, m=m, k=k, d_points=d_points)
-        self.gfg = GlobalFeatureGeneration(d_model=d_model_g, k=128, d_points=d_points)
+        self.gfg = GlobalFeatureGeneration(d_model=d_model_g, k=cfg.model.global_k, d_points=d_points)
         self.lg_att = MultiHeadAttention(n_head, d_model_l, d_model_g, d_model_l // n_head, d_model_l // n_head)
-        self.fc = nn.Linear(d_model_l, d_reduce)
+        self.fc = nn.Sequential(
+            nn.Linear(d_model_l, 256),
+            nn.ReLU(),
+            nn.Linear(256, 256),
+            nn.ReLU(),
+            nn.Linear(256, d_reduce),
+        )
         self.fc_cls = nn.Sequential(
             nn.Linear(k * m * d_reduce, 1024),
             nn.ReLU(),
diff --git a/transformer.py b/models/Nico/transformer.py
similarity index 100%
rename from transformer.py
rename to models/Nico/transformer.py
diff --git a/pointnet_util.py b/pointnet_util.py
index 58aef06fa260c54ad0f0b8dc7ebcb8fced964e87..8e042a973cf968540bec6f73fed09503617db819 100644
--- a/pointnet_util.py
+++ b/pointnet_util.py
@@ -96,7 +96,7 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
     return group_idx
 
 
-def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
+def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
     """
     Input:
         npoint:
@@ -110,11 +110,15 @@ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
     """
     B, N, C = xyz.shape
     S = npoint
-    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
+    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
     torch.cuda.empty_cache()
     new_xyz = index_points(xyz, fps_idx)
     torch.cuda.empty_cache()
-    idx = query_ball_point(radius, nsample, xyz, new_xyz)
+    if knn:
+        dists = square_distance(new_xyz, xyz)  # B x npoint x N
+        idx = dists.argsort()[:, :, :nsample]  # B x npoint x K
+    else:
+        idx = query_ball_point(radius, nsample, xyz, new_xyz)
     torch.cuda.empty_cache()
     grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
     torch.cuda.empty_cache()
@@ -153,11 +157,12 @@ def sample_and_group_all(xyz, points):
 
 
 class PointNetSetAbstraction(nn.Module):
-    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
+    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
         super(PointNetSetAbstraction, self).__init__()
         self.npoint = npoint
         self.radius = radius
         self.nsample = nsample
+        self.knn = knn
         self.mlp_convs = nn.ModuleList()
         self.mlp_bns = nn.ModuleList()
         last_channel = in_channel
@@ -170,20 +175,16 @@ class PointNetSetAbstraction(nn.Module):
     def forward(self, xyz, points):
         """
         Input:
-            xyz: input points position data, [B, C, N]
-            points: input points data, [B, D, N]
+            xyz: input points position data, [B, N, C]
+            points: input points data, [B, N, D]
         Return:
-            new_xyz: sampled points position data, [B, C, S]
-            new_points_concat: sample points feature data, [B, D', S]
+            new_xyz: sampled points position data, [B, S, C]
+            new_points_concat: sample points feature data, [B, S, D']
         """
-        xyz = xyz.permute(0, 2, 1)
-        if points is not None:
-            points = points.permute(0, 2, 1)
-
         if self.group_all:
             new_xyz, new_points = sample_and_group_all(xyz, points)
         else:
-            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
+            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
         # new_xyz: sampled points position data, [B, npoint, C]
         # new_points: sampled points data, [B, npoint, nsample, C+D]
         new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
@@ -191,17 +192,17 @@ class PointNetSetAbstraction(nn.Module):
             bn = self.mlp_bns[i]
             new_points =  F.relu(bn(conv(new_points)))
 
-        new_points = torch.max(new_points, 2)[0]
-        new_xyz = new_xyz.permute(0, 2, 1)
+        new_points = torch.max(new_points, 2)[0].transpose(1, 2)
         return new_xyz, new_points
 
 
 class PointNetSetAbstractionMsg(nn.Module):
-    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
+    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
         super(PointNetSetAbstractionMsg, self).__init__()
         self.npoint = npoint
         self.radius_list = radius_list
         self.nsample_list = nsample_list
+        self.knn = knn
         self.conv_blocks = nn.ModuleList()
         self.bn_blocks = nn.ModuleList()
         for i in range(len(mlp_list)):
@@ -224,9 +225,6 @@ class PointNetSetAbstractionMsg(nn.Module):
-            new_xyz: sampled points position data, [B, C, S]
-            new_points_concat: sample points feature data, [B, D', S]
+            new_xyz: sampled points position data, [B, S, C]
+            new_points_concat: sample points feature data, [B, S, D']
         """
-        xyz = xyz.permute(0, 2, 1)
-        if points is not None:
-            points = points.permute(0, 2, 1)
 
         B, N, C = xyz.shape
         S = self.npoint
@@ -234,7 +232,11 @@ class PointNetSetAbstractionMsg(nn.Module):
         new_points_list = []
         for i, radius in enumerate(self.radius_list):
             K = self.nsample_list[i]
-            group_idx = query_ball_point(radius, K, xyz, new_xyz)
+            if self.knn:
+                dists = square_distance(new_xyz, xyz)  # B x npoint x N
+                group_idx = dists.argsort()[:, :, :K]  # B x npoint x K
+            else:
+                group_idx = query_ball_point(radius, K, xyz, new_xyz)
             grouped_xyz = index_points(xyz, group_idx)
             grouped_xyz -= new_xyz.view(B, S, 1, C)
             if points is not None:
@@ -251,11 +253,11 @@ class PointNetSetAbstractionMsg(nn.Module):
             new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
             new_points_list.append(new_points)
 
-        new_xyz = new_xyz.permute(0, 2, 1)
-        new_points_concat = torch.cat(new_points_list, dim=1)
+        new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
         return new_xyz, new_points_concat
 
 
+# Note: this module still swaps N and C (it has not been converted to the [B, N, C] convention used above)
 class PointNetFeaturePropagation(nn.Module):
     def __init__(self, in_channel, mlp):
         super(PointNetFeaturePropagation, self).__init__()
diff --git a/train.py b/train.py
index ea8629208786c44fa8169fb9c05438ce8ded1414..2b6c1cb702fa13402465740c589872cd62d7bf36 100644
--- a/train.py
+++ b/train.py
@@ -16,8 +16,7 @@ import provider
 import importlib
 import shutil
 import hydra
-from model import PointTransformer
-
+import omegaconf
 
 
 def test(model, loader, num_class=40):
@@ -41,8 +40,11 @@
     instance_acc = np.mean(mean_correct)
     return instance_acc, class_acc
 
-@hydra.main(config_name='config')
+
+@hydra.main(config_path='config', config_name='config')
 def main(args):
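+    # disable struct mode so new keys (num_class, input_dim) can be added to the config below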
+    omegaconf.OmegaConf.set_struct(args, False)
 
     '''HYPER PARAMETER'''
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
@@ -60,10 +62,12 @@
     testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
 
     '''MODEL LOADING'''
-    num_class = 40
-    shutil.copy(hydra.utils.to_absolute_path('model.py'), '.')
+    args.num_class = 40
+    args.input_dim = 6 if args.normal else 3
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
 
-    classifier = PointTransformer(512, 256, 64, m=4, k=64, n_c=40, d_points=6).cuda()
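+    # dynamically import models/<model.name>/model.py and instantiate its PointTransformer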
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformer')(args).cuda()
     criterion = torch.nn.CrossEntropyLoss()
 
     try:
@@ -87,7 +91,7 @@
     else:
         optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
 
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.3)
     global_epoch = 0
     global_step = 0
     best_instance_acc = 0.0