diff --git a/nibio_transformer_semantic/dataset.py b/nibio_transformer_semantic/dataset.py
index 447c8159265e66ed45c049a875781156b88f0acc..87cedfa2f828cf82f01721aeb34114c0081bb7d4 100644
--- a/nibio_transformer_semantic/dataset.py
+++ b/nibio_transformer_semantic/dataset.py
@@ -1,43 +1,57 @@
+import numpy as np
 import os
 from torch.utils.data import Dataset
-import logging
-
-
-logging.basicConfig(level=logging.INFO)
+from pointnet_util import pc_normalize
 
 
 class Dataset(Dataset):
-    '''
-    Dataset class for the Nibio Transformer Semantic Segmentation.
-    There is an assumption that the data is stored in the following format:
-    root
-    ├── 0001.las
-    ├── 0002.las
-    ├── 0003.las
-    ├── ...
-    '''
     def __init__(self,
-                 root,  # root directory of the dataset
-                 npoint=1024,  # number of points to sample
-                 uniform=False,  # sample points uniformly or not
-                 normal_channel=True,  # use normal channel or not)
-                 cache_size=15000,
-                 verbose=True):
-
+                 root='./data/forest_txt/train_txt',
+                 npoints=2500,
+                 normal_channel=False
+                 ):
+        
         self.root = root
-        self.npoints = npoint
-        self.uniform = uniform
+        self.npoints = npoints
         self.normal_channel = normal_channel
-        self.cache_size = cache_size  # how many data points to cache in memory
-        self.cache = {}  # store the data points
-        # get paths to all the files
-        self.datapath = [os.path.join(self.root, f) for f in os.listdir(self.root)]
+        self.datapath = []
+
+        for fn in os.listdir(root):
+            if fn.endswith('.txt'):
+                self.datapath.append(os.path.join(root, fn))
 
-        if verbose:
-            logging.info('The size of the dataset is %d' % len(self.datapath))
+        self.cache = {}  # from index to (point_set, seg) tuple
+        self.cache_size = 20000
+
+    def __getitem__(self, index):
+        if index in self.cache:
+            point_set, seg = self.cache[index]
+        else:
+            fn = self.datapath[index]
+            data = np.loadtxt(fn).astype(np.float32)
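+            # each .txt row holds x, y, z (and nx, ny, nz when normals are used), with the
+            # semantic label in column -2 and the instance label in column -1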
+            if not self.normal_channel:
+                point_set = data[:, 0:3]
+            else:
+                point_set = data[:, 0:6]
+            seg = data[:, -2].astype(np.int32) # -1 is the instance label
+            if len(self.cache) < self.cache_size:
+                self.cache[index] = (point_set, seg)
+        # point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+
+        choice = np.random.choice(len(seg), self.npoints, replace=True)
+        # resample
+        point_set = point_set[choice, :]
+        seg = seg[choice]
+
+        return point_set, seg
 
     def __len__(self):
         return len(self.datapath)
 
-    def __getitem__(self, index):
-        return self._get_item(index)
\ No newline at end of file
+
+if __name__ == '__main__':
+    pass
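+
+    # Minimal usage sketch (assumes .txt tiles are present under the default root):
+    #   ds = Dataset(root='./data/forest_txt/train_txt', npoints=2500)
+    #   pts, seg = ds[0]
+    #   print(pts.shape, seg.shape)  # expected: (2500, 3) and (2500,)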
diff --git a/nibio_transformer_semantic/pipeline.json b/nibio_transformer_semantic/pipeline.json
deleted file mode 100644
index 5b068dd97d791acfd7c95e1ea813999f99af9ff5..0000000000000000000000000000000000000000
--- a/nibio_transformer_semantic/pipeline.json
+++ /dev/null
@@ -1 +0,0 @@
-{"pipeline": [{"type": "readers.las", "filename": "/home/nibio/mutable-outside-world/code/Point-Transformers/data/pointclouds/Plot89_tile_0_0.las"}, {"type": "filters.split", "capacity": 1024, "filename": "output_{}.las"}, {"type": "writer.text", "filename": "/home/nibio/mutable-outside-world/code/Point-Transformers/data/pointclouds_split_{}.txt"}]}
\ No newline at end of file
diff --git a/nibio_transformer_semantic/prepare_data.py b/nibio_transformer_semantic/prepare_data.py
deleted file mode 100644
index 7633f01253e76f8199fb0b4dca4967ab7568f8d6..0000000000000000000000000000000000000000
--- a/nibio_transformer_semantic/prepare_data.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import json
-import os
-import numpy as np
-import laspy
-
-class PrepareData:
-
-    def __init__(self, input_folder, outout_folder, npoint=1024,verbose=True):
-        self.input_folder = input_folder
-        self.output_folder = outout_folder
-        self.npoints = npoint
-        self.verbose = verbose
-
-
-    def process(self):
-        # get paths to all the files
-        self.datapath_input = [os.path.join(self.input_folder, f) for f in os.listdir(self.input_folder)]
-        if self.verbose:
-            print('The size of the dataset is %d' % len(self.datapath_input))
-
-        # create output folder
-        if not os.path.exists(self.output_folder):
-            os.makedirs(self.output_folder)
-
-        # create names for output files
-        self.datapath_output = [os.path.join(self.output_folder, f) for f in os.listdir(self.input_folder)]
-
-        # process each file
-        for input_file, output_file in zip(self.datapath_input, self.datapath_output):
-            self.process_file(capacity=self.npoints, input_file=input_file , output_file=output_file)
-
-    def process_file(self, capacity, input_file, output_file):
-
-        os.system("pdal split --capacity {} --input {} --output {} --writers.las.extra_dims=all".format(capacity, input_file, output_file))
-
-
-        
-
-if __name__ == '__main__':
-    input_folder = "/home/nibio/mutable-outside-world/code/Point-Transformers/data/pointclouds"
-    output_folder = "/home/nibio/mutable-outside-world/code/Point-Transformers/data/pointclouds_split"
-    npoint = 1024
-    prepare_data = PrepareData(input_folder, output_folder, npoint)
-    prepare_data.process()
-
-
-
diff --git a/val_parseg_forest.py b/val_parseg_forest.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee7f1c46b0977179062976be6d8875aa0f28b57
--- /dev/null
+++ b/val_parseg_forest.py
@@ -0,0 +1,192 @@
+"""
+Author: Benny
+Date: Nov 2019
+"""
+import argparse
+import os
+import torch
+import datetime
+import logging
+import sys
+import importlib
+import shutil
+import provider
+import numpy as np
+
+from pathlib import Path
+from tqdm import tqdm
+from dataset import PartNormalDataset
+from nibio_transformer_semantic.dataset import Dataset
+import hydra
+import omegaconf
+
+
+seg_classes = {'tree': [0,1,2,3]}
+seg_label_to_cat = {}  # maps each part label to its category, e.g. {0: 'tree', 1: 'tree', 2: 'tree', 3: 'tree'}
+for cat in seg_classes.keys():
+    for label in seg_classes[cat]:
+        seg_label_to_cat[label] = cat
+
+
+def inplace_relu(m):
+    classname = m.__class__.__name__
+    if classname.find('ReLU') != -1:
+        m.inplace=True
+
+def to_categorical(y, num_classes):
+    """ 1-hot encodes a tensor """
+    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
+    if (y.is_cuda):
+        return new_y.cuda()
+    return new_y
+
+@hydra.main(config_path='config', config_name='partseg')
+def main(args):
+    omegaconf.OmegaConf.set_struct(args, False)
+
+    '''HYPER PARAMETER'''
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+
+    # print('GPU available: {}'.format(torch.cuda.is_available()))
+    logger = logging.getLogger(__name__)
+
+    test_dataset = hydra.utils.to_absolute_path('data/forest_txt/validation_txt/')
+
+    TEST_DATASET = Dataset(root=test_dataset, npoints=args.num_point, normal_channel=args.normal)
+    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
+
+    '''MODEL LOADING'''
+    args.input_dim = (6 if args.normal else 3) + 16
+    args.num_class = 4
+    num_part = args.num_class
+    shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
+
+    # print if gpu is available
+    logger.info('GPU available: {}'.format(torch.cuda.is_available()))
+
+    classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda()
+  
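+    # load the trained weights; Hydra typically switches the working directory to its run
+    # folder, so the relative path below is resolved from there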
+    checkpoint = torch.load('best_model_forest.pth')
+    classifier.load_state_dict(checkpoint['model_state_dict'])
+    logger.info('Using pretrained model')
+  
+    best_acc = 0
+    best_class_avg_iou = 0
+    best_instance_avg_iou = 0
+
+    results_dir = hydra.utils.to_absolute_path('results')
+    # create folder to save the results las files
+    if not os.path.exists(results_dir):
+        os.mkdir(results_dir)
+
+
+    with torch.no_grad():
+        test_metrics = {}
+        total_correct = 0
+        total_seen = 0
+        total_seen_class = [0 for _ in range(num_part)]
+        total_correct_class = [0 for _ in range(num_part)]
+        shape_ious = {cat: [] for cat in seg_classes.keys()}
+        seg_label_to_cat = {}  # maps each part label to its category (rebuilt locally, same as the module-level dict)
+
+        for cat in seg_classes.keys():
+            for label in seg_classes[cat]:
+                seg_label_to_cat[label] = cat
+
+        classifier = classifier.eval()
+
+        for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
+            cur_batch_size, NUM_POINT, _ = points.size()
+            points, label = points.float().cuda(), label.long().cuda()
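+            # append a 16-dim one-hot object-category vector to every point so the input width
+            # matches args.input_dim = (3 or 6) + 16; the category index fed to to_categorical
+            # must be an integer tensor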
+            seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.long).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
+            cur_pred_val = seg_pred.cpu().data.numpy()
+            cur_pred_val_logits = cur_pred_val
+            cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
+            target = label.cpu().data.numpy()
+
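+            # restrict the argmax to the part labels of the sample's category
+            # (always 'tree' here, i.e. labels 0-3)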
+            for i in range(cur_batch_size):
+                cat = seg_label_to_cat[target[i, 0]]
+                logits = cur_pred_val_logits[i, :, :]
+                cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
+
+            # get x,y,z coordinates of points
+            points = points.cpu().data.numpy()
+            points = points[:, :, 0:3]
+            # add predicted labels to points
+            points_pd = np.concatenate([points, np.expand_dims(cur_pred_val, axis=2)], axis=2)
+            points_gt = np.concatenate([points, np.expand_dims(target, axis=2)], axis=2)
+            # save points as text files in the results folder and preserve the same name as the original txt file
+            for i in range(cur_batch_size):
+                np.savetxt(os.path.join(results_dir,
+                    TEST_DATASET.datapath[batch_id * args.batch_size + i].split('/')[-1].replace('.txt', '_pred.txt')
+                    ),
+                           points_pd[i, :, :], fmt='%f %f %f %d')
+                
+                # copy original txt file to the results folder
+                shutil.copy(TEST_DATASET.datapath[batch_id * args.batch_size + i], 
+                    os.path.join(results_dir, TEST_DATASET.datapath[batch_id * args.batch_size + i].split('/')[-1]
+                    ))
+                
+                # save ground truth labels as text files in the results folder and preserve the same name as the original txt file
+                np.savetxt(os.path.join(
+                    results_dir,
+                    TEST_DATASET.datapath[batch_id * args.batch_size + i].split('/')[-1].replace('.txt', '_gt.txt')
+                    ),
+                            points_gt[i, :, :], fmt='%f %f %f %d')
+
+
+            correct = np.sum(cur_pred_val == target)
+            total_correct += correct
+            total_seen += (cur_batch_size * NUM_POINT)
+
+            for l in range(num_part):
+                total_seen_class[l] += np.sum(target == l)
+                total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
+
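+            # per-shape IoU: average over the category's part labels, counting a part that is
+            # absent from both prediction and ground truth as IoU 1.0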
+            for i in range(cur_batch_size):
+                segp = cur_pred_val[i, :]
+                segl = target[i, :]
+                cat = seg_label_to_cat[segl[0]]
+                part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
+                for l in seg_classes[cat]:
+                    if (np.sum(segl == l) == 0) and (
+                            np.sum(segp == l) == 0):  # part is not present, no prediction as well
+                        part_ious[l - seg_classes[cat][0]] = 1.0
+                    else:
+                        part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
+                            np.sum((segl == l) | (segp == l)))
+                shape_ious[cat].append(np.mean(part_ious))
+
+        all_shape_ious = []
+        for cat in shape_ious.keys():
+            for iou in shape_ious[cat]:
+                all_shape_ious.append(iou)
+            shape_ious[cat] = np.mean(shape_ious[cat])
+        mean_shape_ious = np.mean(list(shape_ious.values()))
+        test_metrics['accuracy'] = total_correct / float(total_seen)
+        test_metrics['class_avg_accuracy'] = np.mean(
+            np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32))
+        
+        logger.info('eval accuracy: %f' % test_metrics['accuracy'])
+        for cat in sorted(shape_ious.keys()):
+            logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
+        test_metrics['class_avg_iou'] = mean_shape_ious
+        test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
+
+  
+    if test_metrics['accuracy'] > best_acc:
+        best_acc = test_metrics['accuracy']
+    if test_metrics['class_avg_iou'] > best_class_avg_iou:
+        best_class_avg_iou = test_metrics['class_avg_iou']
+    if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
+        best_instance_avg_iou = test_metrics['instance_avg_iou']
+    logger.info('Best accuracy is: %.5f' % best_acc)
+    logger.info('Best class avg mIoU is: %.5f' % best_class_avg_iou)
+    logger.info('Best instance avg mIoU is: %.5f' % best_instance_avg_iou)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/val_partseg.py b/val_partseg.py
index bb4ec137c9a0aacc0fd70df29bb9d6606fde8fed..f0e593e6b4acbe74b34a30fb76a08ce1e505df76 100644
--- a/val_partseg.py
+++ b/val_partseg.py
@@ -186,8 +186,5 @@ def main(args):
         logger.info('eval avg instance IoU: %f' % (test_metrics['inctance_avg_iou']))
 
 
-
-
-
 if __name__ == '__main__':
     main()
\ No newline at end of file