diff --git a/.gitignore b/.gitignore index 6c1eb2247e93d0fd6f61da12c4c9327db9c93222..eb9d135465ba2c718018923475ec006bf400600e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ __pycache__/ modelnet40_normal_resampled/ outputs/ -log/ \ No newline at end of file +log/ +data/ \ No newline at end of file diff --git a/config/partseg.yaml b/config/partseg.yaml index dccaca0e0ea6b68140684c389ae9a7d544c751ce..447566fc6486b629ad041edff6b352f507623c47 100644 --- a/config/partseg.yaml +++ b/config/partseg.yaml @@ -1,7 +1,7 @@ batch_size: 16 epoch: 200 learning_rate: 1e-3 -gpu: 1 +gpu: 0 num_point: 1024 optimizer: Adam weight_decay: 1e-4 diff --git a/train_partseg.py b/train_partseg.py index a1e9eaccb53823e5893e0dc2c766a466007b04cd..600499185749a4d7a182a6eaeade6661932b795f 100644 --- a/train_partseg.py +++ b/train_partseg.py @@ -48,9 +48,14 @@ def main(args): '''HYPER PARAMETER''' os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + + # print('GPU available: {}'.format(torch.cuda.is_available())) logger = logging.getLogger(__name__) - print(args.pretty()) + # print(args.pretty()) + + # use pretty print to print the config + root = hydra.utils.to_absolute_path('data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') @@ -66,6 +71,9 @@ def main(args): num_part = args.num_class shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.') + # print if gpu is available + logger.info('GPU available: {}'.format(torch.cuda.is_available())) + classifier = getattr(importlib.import_module('models.{}.model'.format(args.model.name)), 'PointTransformerSeg')(args).cuda() criterion = torch.nn.CrossEntropyLoss() @@ -202,7 +210,7 @@ def main(args): mean_shape_ious = np.mean(list(shape_ious.values())) test_metrics['accuracy'] = total_correct / float(total_seen) test_metrics['class_avg_accuracy'] = np.mean( - np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float)) + np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float32)) for cat in sorted(shape_ious.keys()): logger.info('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat])) test_metrics['class_avg_iou'] = mean_shape_ious