diff --git a/train_partseg_forest.py b/train_partseg_forest.py
index 75c961a2597873972e2dde787753f43e426d0b8d..2dad2c781af96ffb58b1f975dcb521512f161257 100644
--- a/train_partseg_forest.py
+++ b/train_partseg_forest.py
@@ -144,7 +144,7 @@ def main(args):
             # print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape)
             # print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape)
 
-            seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
+            seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
             # seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1))
 
 
@@ -180,7 +180,7 @@ def main(args):
             for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                 cur_batch_size, NUM_POINT, _ = points.size()
                 points, label = points.float().cuda(), label.long().cuda()
-                seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
+                seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
                 cur_pred_val = seg_pred.cpu().data.numpy()
                 cur_pred_val_logits = cur_pred_val
                 cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)