Skip to content
Snippets Groups Projects
Commit 3e0285ba authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update training forest - first version

parent b889264d
No related branches found
No related tags found
No related merge requests found
......@@ -144,7 +144,7 @@ def main(args):
# print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape)
# print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape)
seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
# seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1))
......@@ -180,7 +180,7 @@ def main(args):
for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
cur_batch_size, NUM_POINT, _ = points.size()
points, label = points.float().cuda(), label.long().cuda()
seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
cur_pred_val = seg_pred.cpu().data.numpy()
cur_pred_val_logits = cur_pred_val
cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment