diff --git a/train_partseg_forest.py b/train_partseg_forest.py index 2dad2c781af96ffb58b1f975dcb521512f161257..f3ba5a76384c2ef6dc927661134c785c46517263 100644 --- a/train_partseg_forest.py +++ b/train_partseg_forest.py @@ -21,7 +21,7 @@ import hydra import omegaconf -seg_classes = {'tree': [0,1,2,4]} +seg_classes = {'tree': [0,1,2,3]} seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} for cat in seg_classes.keys(): for label in seg_classes[cat]: @@ -186,76 +186,76 @@ def main(args): cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32) target = label.cpu().data.numpy() - # for i in range(cur_batch_size): - # cat = seg_label_to_cat[target[i, 0]] - # logits = cur_pred_val_logits[i, :, :] - # cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] + for i in range(cur_batch_size): + cat = seg_label_to_cat[target[i, 0]] + logits = cur_pred_val_logits[i, :, :] + cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] correct = np.sum(cur_pred_val == target) total_correct += correct total_seen += (cur_batch_size * NUM_POINT) - # for l in range(num_part): - # total_seen_class[l] += np.sum(target == l) - # total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l))) - - # for i in range(cur_batch_size): - # segp = cur_pred_val[i, :] - # segl = target[i, :] - # cat = seg_label_to_cat[segl[0]] - # part_ious = [0.0 for _ in range(len(seg_classes[cat]))] - # for l in seg_classes[cat]: - # if (np.sum(segl == l) == 0) and ( - # np.sum(segp == l) == 0): # part is not present, no prediction as well - # part_ious[l - seg_classes[cat][0]] = 1.0 - # else: - # part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( - # np.sum((segl == l) | (segp == l))) - # shape_ious[cat].append(np.mean(part_ious)) - - # all_shape_ious = [] - # for cat in shape_ious.keys(): - # for iou in shape_ious[cat]: - # all_shape_ious.append(iou) - # shape_ious[cat] = np.mean(shape_ious[cat]) - # mean_shape_ious = np.mean(list(shape_ious.values())) + for l in range(num_part): + total_seen_class[l] += np.sum(target == l) + total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l))) + + for i in range(cur_batch_size): + segp = cur_pred_val[i, :] + segl = target[i, :] + cat = seg_label_to_cat[segl[0]] + part_ious = [0.0 for _ in range(len(seg_classes[cat]))] + for l in seg_classes[cat]: + if (np.sum(segl == l) == 0) and ( + np.sum(segp == l) == 0): # part is not present, no prediction as well + part_ious[l - seg_classes[cat][0]] = 1.0 + else: + part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( + np.sum((segl == l) | (segp == l))) + shape_ious[cat].append(np.mean(part_ious)) + + all_shape_ious = [] + for cat in shape_ious.keys(): + for iou in shape_ious[cat]: + all_shape_ious.append(iou) + shape_ious[cat] = np.mean(shape_ious[cat]) + mean_shape_ious = np.mean(list(shape_ious.values())) test_metrics['accuracy'] = total_correct / float(total_seen) - # test_metrics['class_avg_accuracy'] = np.mean( - # np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32)) + test_metrics['class_avg_accuracy'] = np.mean( + np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32)) print("test_metrics['accuracy']: ", test_metrics['accuracy']) - # for cat in sorted(shape_ious.keys()): - # logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat])) - # test_metrics['class_avg_iou'] = mean_shape_ious - # test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) - - # logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Inctance avg mIOU: %f' % ( - # epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou'])) - # if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): - # logger.info('Save model...') - # savepath = 'best_model.pth' - # logger.info('Saving at %s' % savepath) - # state = { - # 'epoch': epoch, - # 'train_acc': train_instance_acc, - # 'test_acc': test_metrics['accuracy'], - # 'class_avg_iou': test_metrics['class_avg_iou'], - # 'inctance_avg_iou': test_metrics['inctance_avg_iou'], - # 'model_state_dict': classifier.state_dict(), - # 'optimizer_state_dict': optimizer.state_dict(), - # } - # torch.save(state, savepath) - # logger.info('Saving model....') - - # if test_metrics['accuracy'] > best_acc: - # best_acc = test_metrics['accuracy'] - # if test_metrics['class_avg_iou'] > best_class_avg_iou: - # best_class_avg_iou = test_metrics['class_avg_iou'] - # if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: - # best_inctance_avg_iou = test_metrics['inctance_avg_iou'] - # logger.info('Best accuracy is: %.5f' % best_acc) - # logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou) - # logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou) + for cat in sorted(shape_ious.keys()): + logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat])) + test_metrics['class_avg_iou'] = mean_shape_ious + test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) + + logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Inctance avg mIOU: %f' % ( + epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou'])) + if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): + logger.info('Save model...') + savepath = 'best_model_forest.pth' + logger.info('Saving at %s' % savepath) + state = { + 'epoch': epoch, + 'train_acc': train_instance_acc, + 'test_acc': test_metrics['accuracy'], + 'class_avg_iou': test_metrics['class_avg_iou'], + 'inctance_avg_iou': test_metrics['inctance_avg_iou'], + 'model_state_dict': classifier.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + } + torch.save(state, savepath) + logger.info('Saving model....') + + if test_metrics['accuracy'] > best_acc: + best_acc = test_metrics['accuracy'] + if test_metrics['class_avg_iou'] > best_class_avg_iou: + best_class_avg_iou = test_metrics['class_avg_iou'] + if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: + best_inctance_avg_iou = test_metrics['inctance_avg_iou'] + logger.info('Best accuracy is: %.5f' % best_acc) + logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou) + logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou) global_epoch += 1