diff --git a/train_partseg_forest.py b/train_partseg_forest.py index f3ba5a76384c2ef6dc927661134c785c46517263..a2c83a12e073fe934f1c11b18a5b5dc42c28c1f8 100644 --- a/train_partseg_forest.py +++ b/train_partseg_forest.py @@ -55,8 +55,8 @@ def main(args): # use pretty print to print the config - train_dataset = hydra.utils.to_absolute_path('data/forest_txt/validation_txt/') - test_dataset = hydra.utils.to_absolute_path('data/forest_txt/validation_txt/') + train_dataset = hydra.utils.to_absolute_path('data/forest_txt/train_txt/') + test_dataset = hydra.utils.to_absolute_path('data/forest_txt/test_txt/') TRAIN_DATASET = Dataset(root=train_dataset, npoints=args.num_point, normal_channel=args.normal) trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)