diff --git a/train_partseg_forest.py b/train_partseg_forest.py index 75c961a2597873972e2dde787753f43e426d0b8d..2dad2c781af96ffb58b1f975dcb521512f161257 100644 --- a/train_partseg_forest.py +++ b/train_partseg_forest.py @@ -144,7 +144,7 @@ def main(args): # print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape) # print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape) - seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) + seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) # seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1)) @@ -180,7 +180,7 @@ def main(args): for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): cur_batch_size, NUM_POINT, _ = points.size() points, label = points.float().cuda(), label.long().cuda() - seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) + seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) cur_pred_val = seg_pred.cpu().data.numpy() cur_pred_val_logits = cur_pred_val cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)