From 3e0285ba96bfa48c4a5fd2b0d8ed1badf1df512b Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Wed, 8 Mar 2023 14:24:04 +0100 Subject: [PATCH] update training forest - first version --- train_partseg_forest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_partseg_forest.py b/train_partseg_forest.py index 75c961a..2dad2c7 100644 --- a/train_partseg_forest.py +++ b/train_partseg_forest.py @@ -144,7 +144,7 @@ def main(args): # print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape) # print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape) - seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) + seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) # seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1)) @@ -180,7 +180,7 @@ def main(args): for batch_id, (points, label) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): cur_batch_size, NUM_POINT, _ = points.size() points, label = points.float().cuda(), label.long().cuda() - seg_pred = classifier(torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) + seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1)) cur_pred_val = seg_pred.cpu().data.numpy() cur_pred_val_logits = cur_pred_val cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32) -- GitLab