diff --git a/check_stuff.ipynb b/check_stuff.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b1b0ca00de7493e9bcf67ea0ed2414d11cf0ab68
--- /dev/null
+++ b/check_stuff.ipynb
@@ -0,0 +1,218 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dataset import PartNormalDataset\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'\n",
+    "\n",
+    "train_dataset = PartNormalDataset(root=root, npoints=4, split='trainval', normal_channel=True)\n",
+    "\n",
+    "# fetch one sample: (point_set, cls, seg)\n",
+    "a = train_dataset[0]\n",
+    "\n",
+    "print(a)\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pointnet_util import index_points, square_distance\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import numpy as np\n",
+    "\n",
+    "class TransformerBlock(nn.Module):\n",
+    "    def __init__(self, d_points=32, d_model=512, k=16) -> None:\n",
+    "        # d_points: per-point feature dimension (input and output)\n",
+    "        # d_model: internal attention feature dimension\n",
+    "        # k: number of nearest neighbours\n",
+    "\n",
+    "        super().__init__()\n",
+    "        self.fc1 = nn.Linear(d_points, d_model)\n",
+    "        self.fc2 = nn.Linear(d_model, d_points)\n",
+    "        self.fc_delta = nn.Sequential(\n",
+    "            nn.Linear(3, d_model),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(d_model, d_model)\n",
+    "        )\n",
+    "        self.fc_gamma = nn.Sequential(\n",
+    "            nn.Linear(d_model, d_model),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(d_model, d_model)\n",
+    "        )\n",
+    "        self.w_qs = nn.Linear(d_model, d_model, bias=False)\n",
+    "        self.w_ks = nn.Linear(d_model, d_model, bias=False)\n",
+    "        self.w_vs = nn.Linear(d_model, d_model, bias=False)\n",
+    "        self.k = k\n",
+    "        \n",
+    "    # xyz: b x n x 3, features: b x n x f\n",
+    "    def forward(self, xyz, features):\n",
+    "        dists = square_distance(xyz, xyz)\n",
+    "        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k\n",
+    "        knn_xyz = index_points(xyz, knn_idx)\n",
+    "        \n",
+    "        pre = features\n",
+    "        x = self.fc1(features)\n",
+    "        q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx)\n",
+    "\n",
+    "        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x f\n",
+    "        \n",
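+    "        # vector attention: fc_gamma maps (q - k + pos_enc) to per-channel weights\n",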
+    "        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)\n",
+    "        attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2)  # b x n x k x f\n",
+    "        \n",
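+    "        # per-channel weights times (values + position encoding), summed over the k neighbours\n",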
+    "        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)\n",
+    "        res = self.fc2(res) + pre\n",
+    "        return res, attn\n",
+    "\n",
+    "\n",
+    "# sample input: batch of 1, three points, three features per point\n",
+    "xyz = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])\n",
+    "features = torch.tensor([[[0.2, 0.3, 0.4], [0.6, 0.7, 0.8], [1.0, 1.1, 1.2]]])  # b x n x d_points\n",
+    "\n",
+    "# Create a TransformerBlock instance\n",
+    "transformer_block = TransformerBlock(d_points=3, d_model=4, k=2)\n",
+    "\n",
+    "# Call the forward method\n",
+    "output, attn = transformer_block(xyz, features)\n",
+    "\n",
+    "# Print the output shape and attention shape\n",
+    "print(\"Output shape:\", output.shape)\n",
+    "print(\"Attention shape:\", attn.shape)"
+   ]
+  },
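+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Sanity check for the two helpers the block above imports from `pointnet_util`. The reference versions below are a minimal sketch reconstructed from their call sites, not copies of the actual implementations, so treat the exact semantics as an assumption."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "def square_distance_ref(src, dst):\n",
+    "    # src: b x n x 3, dst: b x m x 3 -> b x n x m pairwise squared distances\n",
+    "    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)\n",
+    "\n",
+    "def index_points_ref(points, idx):\n",
+    "    # points: b x n x f, idx: b x n x k -> b x n x k x f gathered neighbours\n",
+    "    batch_idx = torch.arange(points.shape[0]).view(-1, 1, 1)\n",
+    "    return points[batch_idx, idx]\n",
+    "\n",
+    "pts = torch.rand(1, 5, 3)\n",
+    "knn = square_distance_ref(pts, pts).argsort()[:, :, :2]  # 2 nearest neighbours\n",
+    "print(index_points_ref(pts, knn).shape)  # expect torch.Size([1, 5, 2, 3])\n"
+   ]
+  },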
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n",
+      "torch.Size([1, 1024, 6]) torch.Size([1, 1024])\n"
+     ]
+    }
+   ],
+   "source": [
+    "from nibio_transformer_semantic.dataset import Dataset\n",
+    "import torch\n",
+    "\n",
+    "dataset = Dataset(\n",
+    "    root='data/forest_txt/validation_txt/', \n",
+    "    npoints=1024, \n",
+    "    normal_channel=True\n",
+    "    )\n",
+    "\n",
+    "trainDataLoader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=10, drop_last=True)\n",
+    "\n",
+    "# print shapes for up to 24 batches (the loader may exhaust sooner)\n",
+    "for i, data in enumerate(trainDataLoader):\n",
+    "    if i == 24:\n",
+    "        break\n",
+    "    print(data[0].shape, data[1].shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "def to_categorical(y, num_classes):\n",
+    "    \"\"\" 1-hot encodes a tensor \"\"\"\n",
+    "    print(\"num_classes: \", num_classes)\n",
+    "    print(\"y: \", y)\n",
+    "    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]\n",
+    "    if (y.is_cuda):\n",
+    "        return new_y.cuda()\n",
+    "    return new_y\n",
+    "\n",
+    "y = torch.tensor([0, 1, 2, 3, 4])\n",
+    "\n",
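+    "# expect a 5 x 10 matrix with a single 1 per row at positions 0..4\n",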
+    "print(to_categorical(y, 10))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenetcore_partanno_segmentation_benchmark_v0_normal/02691156/14cd2f1de7f68bf3ab550998f901c8e1.txt'"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from dataset import PartNormalDataset\n",
+    "import torch\n",
+    "root = '/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenetcore_partanno_segmentation_benchmark_v0_normal'\n",
+    "\n",
+    "TEST_DATASET = PartNormalDataset(root=root, npoints=2, split='test', normal_channel=True)\n",
+    "testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=1, shuffle=False, num_workers=1)\n",
+    "\n",
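+    "# index 1 of a datapath entry is the source txt file path\n",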
+    "TEST_DATASET.datapath[3][1]\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dataset.py b/dataset.py
index 41d0d61bd3951b8beedd7dd208dea7b3b5a7d2e3..8f3ac5b010927f4f73056ad9b1727ec5e86ab2c9 100644
--- a/dataset.py
+++ b/dataset.py
@@ -163,9 +163,6 @@ class PartNormalDataset(Dataset):
         point_set = point_set[choice, :]
         seg = seg[choice]
 
-        print("cls", cls)
-        print("seg", seg) 
-
         return point_set, cls, seg
 
     def __len__(self):
diff --git a/las2text_mapper.py b/las2text_mapper.py
index d70c6fd0ab279362743e27de60e6c59909b47cc8..d4b354ff7230247b91536d22f5a3946eaf2145d7 100644
--- a/las2text_mapper.py
+++ b/las2text_mapper.py
@@ -47,6 +47,9 @@ class Las2TextMapper:
         # put all together to pandas dataframe
         points = pd.DataFrame(points, columns=['x', 'y', 'z', 'red', 'green', 'blue', 'label', 'treeID'])
 
+        # shift labels down by one so they run 0..3 instead of 1..4
+        points['label'] = points['label'] - 1
+
         return points
     
     def process_single_file(self, filepath):
@@ -75,9 +78,11 @@ class Las2TextMapper:
         """
         # read all las files in the folder data_dir using glob
         list_of_files = glob.glob(self.data_dir + "/*.las", recursive=False)
-
-        Parallel(n_jobs=8)(delayed(self.process_single_file)(filepath) for filepath in list_of_files)
-
+
+        Parallel(n_jobs=-1)(
+            delayed(self.process_single_file)(filepath) for filepath in list_of_files
+        )
+
         if self.verbose:
             print("Done processing the folder")
 
diff --git a/models/Hengshuang/__init__.py b/models/Hengshuang/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/train_partseg.py b/train_partseg.py
index 600499185749a4d7a182a6eaeade6661932b795f..996b9ecd64b643cc3165ecc053e952f74c3f6b0c 100644
--- a/train_partseg.py
+++ b/train_partseg.py
@@ -137,6 +137,9 @@ def main(args):
             points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
             optimizer.zero_grad()
 
+            # the one-hot shape-category label is tiled across all points and
+            # concatenated to the point features before segmentation
+
             seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
             seg_pred = seg_pred.contiguous().view(-1, num_part)
             target = target.view(-1, 1)[:, 0]
diff --git a/train_partseg_forest.py b/train_partseg_forest.py
new file mode 100644
index 0000000000000000000000000000000000000000..75c961a2597873972e2dde787753f43e426d0b8d
--- /dev/null
+++ b/train_partseg_forest.py
@@ -0,0 +1,263 @@
+"""
+Author: Benny
+Date: Nov 2019
+"""
+import argparse
+import os
+import torch
+import datetime
+import logging
+import sys
+import importlib
+import shutil
+import provider
+import numpy as np
+
+from pathlib import Path
+from tqdm import tqdm
+from dataset import PartNormalDataset
+from nibio_transformer_semantic.dataset import Dataset
+import hydra
+import omegaconf
+
+
+seg_classes = {'tree': [0, 1, 2, 3]}
+seg_label_to_cat = {}  # {0:'tree', 1:'tree', ...}
+for cat in seg_classes.keys():
+    for label in seg_classes[cat]:
+        seg_label_to_cat[label] = cat
+
+
+def inplace_relu(m):
+    classname = m.__class__.__name__
+    if classname.find('ReLU') != -1:
+        m.inplace=True
+
+def to_categorical(y, num_classes):
+    """ 1-hot encodes a tensor """
+    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
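+    # indexing the identity matrix with the labels yields one one-hot row per label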
+    if (y.is_cuda):
+        return new_y.cuda()
+    return new_y
+
+@hydra.main(config_path='config', config_name='partseg')
+def main(args):
+    omegaconf.OmegaConf.set_struct(args, False)
+
+    '''HYPER PARAMETER'''
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+
+    logger = logging.getLogger(__name__)
+
+    # NOTE: train and test currently point at the same validation split
+    train_dataset = hydra.utils.to_absolute_path('data/forest_txt/validation_txt/')
+    test_dataset = hydra.utils.to_absolute_path('data/forest_txt/validation_txt/')
+
+    TRAIN_DATASET = Dataset(root=train_dataset, npoints=args.num_point, normal_channel=args.normal)
+    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
+    TEST_DATASET = Dataset(root=test_dataset, npoints=args.num_point, normal_channel=args.normal)
+    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
+
+    '''MODEL LOADING'''
+    args.input_dim = (6 if args.normal else 3) + 16  # xyz(+normals) plus the 16-dim one-hot category channel
+    args.num_class = 4
+    num_category = 1
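+    # only the 'tree' category exists, but 16 one-hot channels stay reserved in
+    # input_dim to match the constant one-hot built in the training loop below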
+    num_part = args.num_class
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
+
+    # print if gpu is available
+    logger.info('GPU available: {}'.format(torch.cuda.is_available()))
+
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
+    criterion = torch.nn.CrossEntropyLoss()
+
+    try:
+        checkpoint = torch.load('best_model.pth')
+        start_epoch = checkpoint['epoch']
+        classifier.load_state_dict(checkpoint['model_state_dict'])
+        logger.info('Using pretrained model')
+    except Exception:
+        logger.info('No existing model, starting training from scratch...')
+        start_epoch = 0
+
+    if args.optimizer == 'Adam':
+        optimizer = torch.optim.Adam(
+            classifier.parameters(),
+            lr=args.learning_rate,
+            betas=(0.9, 0.999),
+            eps=1e-08,
+            weight_decay=args.weight_decay
+        )
+    else:
+        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
+
+    def bn_momentum_adjust(m, momentum):
+        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
+            m.momentum = momentum
+
+    LEARNING_RATE_CLIP = 1e-5
+    MOMENTUM_ORIGINAL = 0.1
+    MOMENTUM_DECAY = 0.5
+    MOMENTUM_DECAY_STEP = args.step_size
+
+    best_acc = 0
+    global_epoch = 0
+
+    for epoch in range(start_epoch, args.epoch):
+        mean_correct = []
+
+        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
+        '''Adjust learning rate and BN momentum'''
+        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
+        logger.info('Learning rate:%f' % lr)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
+        if momentum < 0.01:
+            momentum = 0.01
+        logger.info('BN momentum updated to: %f' % momentum)
+        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
+        classifier = classifier.train()
+
+        '''learning one epoch'''
+        for i, (points, label) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
+            points = points.data.numpy()
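+            # augmentation: random scaling and a global shift on the xyz channels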
+            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
+            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
+            points = torch.Tensor(points)
+
+            points, label = points.float().cuda(), label.long().cuda()
+            optimizer.zero_grad()
+
+            # a constant one-hot (index 1 of 16, matching input_dim) stands in
+            # for the per-shape category label and is tiled over batch and points
+            one_hot = to_categorical(torch.ones(points.shape[0], 1, dtype=torch.long).cuda(), 16)
+            seg_pred = classifier(torch.cat([points, one_hot.repeat(1, points.shape[1], 1)], -1))
+            seg_pred = seg_pred.contiguous().view(-1, num_part)
+            target = label.view(-1, 1)[:, 0]
+            pred_choice = seg_pred.data.max(1)[1]
+
+            correct = pred_choice.eq(target.data).cpu().sum()
+            mean_correct.append(correct.item() / (args.batch_size * args.num_point))
+            loss = criterion(seg_pred, target)
+            loss.backward()
+            optimizer.step()
+
+        train_instance_acc = np.mean(mean_correct)
+        logger.info('Train accuracy is: %.5f' % train_instance_acc)
+
+        with torch.no_grad():
+            test_metrics = {}
+            total_correct = 0
+            total_seen = 0
+
+            classifier = classifier.eval()
+
+            for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
+                cur_batch_size, NUM_POINT, _ = points.size()
+                points, label = points.float().cuda(), label.long().cuda()
+                one_hot = to_categorical(torch.ones(points.shape[0], 1, dtype=torch.long).cuda(), 16)
+                seg_pred = classifier(torch.cat([points, one_hot.repeat(1, points.shape[1], 1)], -1))
+                cur_pred_val = seg_pred.cpu().data.numpy()
+                cur_pred_val_logits = cur_pred_val
+                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
+                target = label.cpu().data.numpy()
+
+                # with a single category, predictions are a plain per-point argmax
+                # (the ShapeNet variant restricted argmax to each category's labels)
+                for i in range(cur_batch_size):
+                    logits = cur_pred_val_logits[i, :, :]
+                    cur_pred_val[i, :] = np.argmax(logits, 1)
+
+                correct = np.sum(cur_pred_val == target)
+                total_correct += correct
+                total_seen += (cur_batch_size * NUM_POINT)
+
+            # per-class accuracy and IoU bookkeeping from the ShapeNet script is
+            # omitted here; only overall point accuracy is tracked
+            test_metrics['accuracy'] = total_correct / float(total_seen)
+
+            logger.info('eval point accuracy: %f' % test_metrics['accuracy'])
+        # checkpoint on best overall accuracy; the IoU-based criterion from the
+        # ShapeNet script does not apply to the single-category forest data
+        if test_metrics['accuracy'] >= best_acc:
+            best_acc = test_metrics['accuracy']
+            savepath = 'best_model.pth'
+            logger.info('Saving best model at %s' % savepath)
+            state = {
+                'epoch': epoch,
+                'train_acc': train_instance_acc,
+                'test_acc': test_metrics['accuracy'],
+                'model_state_dict': classifier.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+            }
+            torch.save(state, savepath)
+        logger.info('Best accuracy is: %.5f' % best_acc)
+        global_epoch += 1
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/train_partseg_single.py b/train_partseg_single.py
deleted file mode 100644
index 600499185749a4d7a182a6eaeade6661932b795f..0000000000000000000000000000000000000000
--- a/train_partseg_single.py
+++ /dev/null
@@ -1,250 +0,0 @@
-"""
-Author: Benny
-Date: Nov 2019
-"""
-import argparse
-import os
-import torch
-import datetime
-import logging
-import sys
-import importlib
-import shutil
-import provider
-import numpy as np
-
-from pathlib import Path
-from tqdm import tqdm
-from dataset import PartNormalDataset
-import hydra
-import omegaconf
-
-
-seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
-               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
-               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
-               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
-seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
-for cat in seg_classes.keys():
-    for label in seg_classes[cat]:
-        seg_label_to_cat[label] = cat
-
-
-def inplace_relu(m):
-    classname = m.__class__.__name__
-    if classname.find('ReLU') != -1:
-        m.inplace=True
-
-def to_categorical(y, num_classes):
-    """ 1-hot encodes a tensor """
-    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
-    if (y.is_cuda):
-        return new_y.cuda()
-    return new_y
-
-@hydra.main(config_path='config', config_name='partseg')
-def main(args):
-    omegaconf.OmegaConf.set_struct(args, False)
-
-    '''HYPER PARAMETER'''
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
-
-    # print('GPU available: {}'.format(torch.cuda.is_available()))
-    logger = logging.getLogger(__name__)
-
-    # print(args.pretty())
-
-    # use pretty print to print the config
-    
-
-    root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/')
-
-    TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal)
-    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
-    TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
-    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
-
-    '''MODEL LOADING'''
-    args.input_dim = (6 if args.normal else 3) + 16
-    args.num_class = 50
-    num_category = 16
-    num_part = args.num_class
-    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
-
-    # print if gpu is available
-    logger.info('GPU available: {}'.format(torch.cuda.is_available()))
-
-    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
-    criterion = torch.nn.CrossEntropyLoss()
-
-    try:
-        checkpoint = torch.load('best_model.pth')
-        start_epoch = checkpoint['epoch']
-        classifier.load_state_dict(checkpoint['model_state_dict'])
-        logger.info('Use pretrain model')
-    except:
-        logger.info('No existing model, starting training from scratch...')
-        start_epoch = 0
-
-    if args.optimizer == 'Adam':
-        optimizer = torch.optim.Adam(
-            classifier.parameters(),
-            lr=args.learning_rate,
-            betas=(0.9, 0.999),
-            eps=1e-08,
-            weight_decay=args.weight_decay
-        )
-    else:
-        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
-
-    def bn_momentum_adjust(m, momentum):
-        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
-            m.momentum = momentum
-
-    LEARNING_RATE_CLIP = 1e-5
-    MOMENTUM_ORIGINAL = 0.1
-    MOMENTUM_DECCAY = 0.5
-    MOMENTUM_DECCAY_STEP = args.step_size
-
-    best_acc = 0
-    global_epoch = 0
-    best_class_avg_iou = 0
-    best_inctance_avg_iou = 0
-
-    for epoch in range(start_epoch, args.epoch):
-        mean_correct = []
-
-        logger.info('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
-        '''Adjust learning rate and BN momentum'''
-        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
-        logger.info('Learning rate:%f' % lr)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
-        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))
-        if momentum < 0.01:
-            momentum = 0.01
-        print('BN momentum updated to: %f' % momentum)
-        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
-        classifier = classifier.train()
-
-        '''learning one epoch'''
-        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
-            points = points.data.numpy()
-            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
-            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
-            points = torch.Tensor(points)
-
-            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
-            optimizer.zero_grad()
-
-            seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
-            seg_pred = seg_pred.contiguous().view(-1, num_part)
-            target = target.view(-1, 1)[:, 0]
-            pred_choice = seg_pred.data.max(1)[1]
-
-            correct = pred_choice.eq(target.data).cpu().sum()
-            mean_correct.append(correct.item() / (args.batch_size * args.num_point))
-            loss = criterion(seg_pred, target)
-            loss.backward()
-            optimizer.step()
-
-        train_instance_acc = np.mean(mean_correct)
-        logger.info('Train accuracy is: %.5f' % train_instance_acc)
-
-        with torch.no_grad():
-            test_metrics = {}
-            total_correct = 0
-            total_seen = 0
-            total_seen_class = [0 for _ in range(num_part)]
-            total_correct_class = [0 for _ in range(num_part)]
-            shape_ious = {cat: [] for cat in seg_classes.keys()}
-            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
-
-            for cat in seg_classes.keys():
-                for label in seg_classes[cat]:
-                    seg_label_to_cat[label] = cat
-
-            classifier = classifier.eval()
-
-            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
-                cur_batch_size, NUM_POINT, _ = points.size()
-                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
-                seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
-                cur_pred_val = seg_pred.cpu().data.numpy()
-                cur_pred_val_logits = cur_pred_val
-                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
-                target = target.cpu().data.numpy()
-
-                for i in range(cur_batch_size):
-                    cat = seg_label_to_cat[target[i, 0]]
-                    logits = cur_pred_val_logits[i, :, :]
-                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
-
-                correct = np.sum(cur_pred_val == target)
-                total_correct += correct
-                total_seen += (cur_batch_size * NUM_POINT)
-
-                for l in range(num_part):
-                    total_seen_class[l] += np.sum(target == l)
-                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
-
-                for i in range(cur_batch_size):
-                    segp = cur_pred_val[i, :]
-                    segl = target[i, :]
-                    cat = seg_label_to_cat[segl[0]]
-                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
-                    for l in seg_classes[cat]:
-                        if (np.sum(segl == l) == 0) and (
-                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
-                            part_ious[l - seg_classes[cat][0]] = 1.0
-                        else:
-                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
-                                np.sum((segl == l) | (segp == l)))
-                    shape_ious[cat].append(np.mean(part_ious))
-
-            all_shape_ious = []
-            for cat in shape_ious.keys():
-                for iou in shape_ious[cat]:
-                    all_shape_ious.append(iou)
-                shape_ious[cat] = np.mean(shape_ious[cat])
-            mean_shape_ious = np.mean(list(shape_ious.values()))
-            test_metrics['accuracy'] = total_correct / float(total_seen)
-            test_metrics['class_avg_accuracy'] = np.mean(
-                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32))
-            for cat in sorted(shape_ious.keys()):
-                logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
-            test_metrics['class_avg_iou'] = mean_shape_ious
-            test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)
-
-        logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Inctance avg mIOU: %f' % (
-            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
-        if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
-            logger.info('Save model...')
-            savepath = 'best_model.pth'
-            logger.info('Saving at %s' % savepath)
-            state = {
-                'epoch': epoch,
-                'train_acc': train_instance_acc,
-                'test_acc': test_metrics['accuracy'],
-                'class_avg_iou': test_metrics['class_avg_iou'],
-                'inctance_avg_iou': test_metrics['inctance_avg_iou'],
-                'model_state_dict': classifier.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-            }
-            torch.save(state, savepath)
-            logger.info('Saving model....')
-
-        if test_metrics['accuracy'] > best_acc:
-            best_acc = test_metrics['accuracy']
-        if test_metrics['class_avg_iou'] > best_class_avg_iou:
-            best_class_avg_iou = test_metrics['class_avg_iou']
-        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
-            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
-        logger.info('Best accuracy is: %.5f' % best_acc)
-        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
-        logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
-        global_epoch += 1
-
-
-if __name__ == '__main__':
-    main()
\ No newline at end of file
diff --git a/val_partseg.py b/val_partseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb4ec137c9a0aacc0fd70df29bb9d6606fde8fed
--- /dev/null
+++ b/val_partseg.py
@@ -0,0 +1,193 @@
+"""
+Author: Benny
+Date: Nov 2019
+"""
+import argparse
+import os
+import torch
+import datetime
+import logging
+import sys
+import importlib
+import shutil
+import provider
+import numpy as np
+
+from pathlib import Path
+from tqdm import tqdm
+from dataset import PartNormalDataset
+import hydra
+import omegaconf
+
+
+seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
+               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
+               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
+               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
+seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
+for cat in seg_classes.keys():
+    for label in seg_classes[cat]:
+        seg_label_to_cat[label] = cat
+
+
+def inplace_relu(m):
+    classname = m.__class__.__name__
+    if classname.find('ReLU') != -1:
+        m.inplace=True
+
+def to_categorical(y, num_classes):
+    """ 1-hot encodes a tensor """
+    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
+    if (y.is_cuda):
+        return new_y.cuda()
+    return new_y
+
+@hydra.main(config_path='config', config_name='partseg')
+def main(args):
+    omegaconf.OmegaConf.set_struct(args, False)
+
+    '''HYPER PARAMETER'''
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+
+    logger = logging.getLogger(__name__)
+
+    root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/')
+
+    TEST_DATASET = PartNormalDataset(root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
+    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
+
+    '''MODEL LOADING'''
+    args.input_dim = (6 if args.normal else 3) + 16
+    args.num_class = 50
+    num_category = 16
+    num_part = args.num_class
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
+
+    # print if gpu is available
+    logger.info('GPU available: {}'.format(torch.cuda.is_available()))
+
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
+
+    # load the pretrained model
+    checkpoint = torch.load('best_model.pth')
+    classifier.load_state_dict(checkpoint['model_state_dict'])
+    logger.info('Using pretrained model')
+
+    results_dir = hydra.utils.to_absolute_path('results')
+    # create a folder for the result text files
+    os.makedirs(results_dir, exist_ok=True)
+
+    with torch.no_grad():
+        test_metrics = {}
+        total_correct = 0
+        total_seen = 0
+        total_seen_class = [0 for _ in range(num_part)]
+        total_correct_class = [0 for _ in range(num_part)]
+        shape_ious = {cat: [] for cat in seg_classes.keys()}
+        seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
+
+        for cat in seg_classes.keys():
+            for label in seg_classes[cat]:
+                seg_label_to_cat[label] = cat
+
+        classifier = classifier.eval()
+
+        for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
+            cur_batch_size, NUM_POINT, _ = points.size()
+            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
+            seg_pred = classifier(torch.cat([points, to_categorical(label, num_category).repeat(1, points.shape[1], 1)], -1))
+            cur_pred_val = seg_pred.cpu().data.numpy()
+            cur_pred_val_logits = cur_pred_val
+            cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
+            target = target.cpu().data.numpy()
+
+            for i in range(cur_batch_size):
+                cat = seg_label_to_cat[target[i, 0]]
+                logits = cur_pred_val_logits[i, :, :]
+                cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
+
+            # get x,y,z coordinates of points
+            points = points.cpu().data.numpy()
+            points = points[:, :, 0:3]
+            # add predicted labels to points
+            points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2)
+            points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2)
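+            # points_pd / points_gt: (b, n, 3) coords + (b, n, 1) labels -> (b, n, 4) rows of "x y z label"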
+            # save predictions, a copy of the original cloud, and ground truth
+            # under the original file name in the results folder
+            for i in range(cur_batch_size):
+                src_path = TEST_DATASET.datapath[batch_id * args.batch_size + i][1]
+                src_name = os.path.basename(src_path)
+                out_dir = os.path.join(results_dir, os.path.basename(os.path.dirname(src_path)))
+                os.makedirs(out_dir, exist_ok=True)
+                np.savetxt(os.path.join(out_dir, src_name.replace('.txt', '_pred.txt')),
+                           points_pd[i, :, :], fmt='%f %f %f %d')
+                shutil.copy(src_path, os.path.join(out_dir, src_name))
+                np.savetxt(os.path.join(out_dir, src_name.replace('.txt', '_gt.txt')),
+                           points_gt[i, :, :], fmt='%f %f %f %d')
+
+            correct = np.sum(cur_pred_val == target)
+            total_correct += correct
+            total_seen += (cur_batch_size * NUM_POINT)
+
+            for l in range(num_part):
+                total_seen_class[l] += np.sum(target == l)
+                total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
+
+            for i in range(cur_batch_size):
+                segp = cur_pred_val[i, :]
+                segl = target[i, :]
+                cat = seg_label_to_cat[segl[0]]
+                part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
+                for l in seg_classes[cat]:
+                    if (np.sum(segl == l) == 0) and (
+                            np.sum(segp == l) == 0):  # part is not present, no prediction as well
+                        part_ious[l - seg_classes[cat][0]] = 1.0
+                    else:
+                        part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
+                            np.sum((segl == l) | (segp == l)))
+                shape_ious[cat].append(np.mean(part_ious))
+
+        all_shape_ious = []
+        for cat in shape_ious.keys():
+            for iou in shape_ious[cat]:
+                all_shape_ious.append(iou)
+            shape_ious[cat] = np.mean(shape_ious[cat])
+        mean_shape_ious = np.mean(list(shape_ious.values()))
+        test_metrics['accuracy'] = total_correct / float(total_seen)
+        test_metrics['class_avg_accuracy'] = np.mean(
+            np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32))
+        for cat in sorted(shape_ious.keys()):
+            logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
+        test_metrics['class_avg_iou'] = mean_shape_ious
+        test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
+
+        logger.info('eval accuracy: %f' % (test_metrics['accuracy']))
+        logger.info('eval avg class acc: %f' % (test_metrics['class_avg_accuracy']))
+        logger.info('eval avg class IoU: %f' % (test_metrics['class_avg_iou']))
+        logger.info('eval avg instance IoU: %f' % (test_metrics['instance_avg_iou']))
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file