diff --git a/train_partseg_forest.py b/train_partseg_forest.py
index 2dad2c781af96ffb58b1f975dcb521512f161257..f3ba5a76384c2ef6dc927661134c785c46517263 100644
--- a/train_partseg_forest.py
+++ b/train_partseg_forest.py
@@ -21,7 +21,7 @@ import hydra
 import omegaconf
 
 
-seg_classes = {'tree': [0,1,2,4]}
+seg_classes = {'tree': [0,1,2,3]}
 seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
 for cat in seg_classes.keys():
     for label in seg_classes[cat]:
@@ -186,76 +186,76 @@ def main(args):
                 cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
                 target = label.cpu().data.numpy()
 
-                # for i in range(cur_batch_size):
-                #     cat = seg_label_to_cat[target[i, 0]]
-                #     logits = cur_pred_val_logits[i, :, :]
-                #     cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
+                for i in range(cur_batch_size):
+                    cat = seg_label_to_cat[target[i, 0]]
+                    logits = cur_pred_val_logits[i, :, :]
+                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
 
                 correct = np.sum(cur_pred_val == target)
                 total_correct += correct
                 total_seen += (cur_batch_size * NUM_POINT)
 
-                # for l in range(num_part):
-                #     total_seen_class[l] += np.sum(target == l)
-                #     total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
-
-                # for i in range(cur_batch_size):
-                #     segp = cur_pred_val[i, :]
-                #     segl = target[i, :]
-                #     cat = seg_label_to_cat[segl[0]]
-                #     part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
-                #     for l in seg_classes[cat]:
-                #         if (np.sum(segl == l) == 0) and (
-                #                 np.sum(segp == l) == 0):  # part is not present, no prediction as well
-                #             part_ious[l - seg_classes[cat][0]] = 1.0
-                #         else:
-                #             part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
-                #                 np.sum((segl == l) | (segp == l)))
-                #     shape_ious[cat].append(np.mean(part_ious))
-
-            # all_shape_ious = []
-            # for cat in shape_ious.keys():
-            #     for iou in shape_ious[cat]:
-            #         all_shape_ious.append(iou)
-            #     shape_ious[cat] = np.mean(shape_ious[cat])
-            # mean_shape_ious = np.mean(list(shape_ious.values()))
+                for l in range(num_part):
+                    total_seen_class[l] += np.sum(target == l)
+                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
+
+                for i in range(cur_batch_size):
+                    segp = cur_pred_val[i, :]
+                    segl = target[i, :]
+                    cat = seg_label_to_cat[segl[0]]
+                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
+                    for l in seg_classes[cat]:
+                        if (np.sum(segl == l) == 0) and (
+                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
+                            part_ious[l - seg_classes[cat][0]] = 1.0
+                        else:
+                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
+                                np.sum((segl == l) | (segp == l)))
+                    shape_ious[cat].append(np.mean(part_ious))
+
+            all_shape_ious = []
+            for cat in shape_ious.keys():
+                for iou in shape_ious[cat]:
+                    all_shape_ious.append(iou)
+                shape_ious[cat] = np.mean(shape_ious[cat])
+            mean_shape_ious = np.mean(list(shape_ious.values()))
             test_metrics['accuracy'] = total_correct / float(total_seen)
-            # test_metrics['class_avg_accuracy'] = np.mean(
-            #     np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32))
+            test_metrics['class_avg_accuracy'] = np.mean(
+                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32))
             
             print("test_metrics['accuracy']: ", test_metrics['accuracy'])
-        #     for cat in sorted(shape_ious.keys()):
-        #         logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
-        #     test_metrics['class_avg_iou'] = mean_shape_ious
-        #     test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)
-
-        # logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Inctance avg mIOU: %f' % (
-        #     epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
-        # if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
-        #     logger.info('Save model...')
-        #     savepath = 'best_model.pth'
-        #     logger.info('Saving at %s' % savepath)
-        #     state = {
-        #         'epoch': epoch,
-        #         'train_acc': train_instance_acc,
-        #         'test_acc': test_metrics['accuracy'],
-        #         'class_avg_iou': test_metrics['class_avg_iou'],
-        #         'inctance_avg_iou': test_metrics['inctance_avg_iou'],
-        #         'model_state_dict': classifier.state_dict(),
-        #         'optimizer_state_dict': optimizer.state_dict(),
-        #     }
-        #     torch.save(state, savepath)
-        #     logger.info('Saving model....')
-
-        # if test_metrics['accuracy'] > best_acc:
-        #     best_acc = test_metrics['accuracy']
-        # if test_metrics['class_avg_iou'] > best_class_avg_iou:
-        #     best_class_avg_iou = test_metrics['class_avg_iou']
-        # if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
-        #     best_inctance_avg_iou = test_metrics['inctance_avg_iou']
-        # logger.info('Best accuracy is: %.5f' % best_acc)
-        # logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
-        # logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
+            for cat in sorted(shape_ious.keys()):
+                logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
+            test_metrics['class_avg_iou'] = mean_shape_ious
+            test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)
+
+        logger.info('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Inctance avg mIOU: %f' % (
+            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
+        if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
+            logger.info('Save model...')
+            savepath = 'best_model_forest.pth'
+            logger.info('Saving at %s' % savepath)
+            state = {
+                'epoch': epoch,
+                'train_acc': train_instance_acc,
+                'test_acc': test_metrics['accuracy'],
+                'class_avg_iou': test_metrics['class_avg_iou'],
+                'inctance_avg_iou': test_metrics['inctance_avg_iou'],
+                'model_state_dict': classifier.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+            }
+            torch.save(state, savepath)
+            logger.info('Saving model....')
+
+        if test_metrics['accuracy'] > best_acc:
+            best_acc = test_metrics['accuracy']
+        if test_metrics['class_avg_iou'] > best_class_avg_iou:
+            best_class_avg_iou = test_metrics['class_avg_iou']
+        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
+            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
+        logger.info('Best accuracy is: %.5f' % best_acc)
+        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
+        logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
         global_epoch += 1