diff --git a/train_partseg_forest.py b/train_partseg_forest.py
index 9bb4a8561e44b1fc8521bc04cf8ec3fb3e510d31..4265d52430488b905d044388b47a5dcc4b00e062 100644
--- a/train_partseg_forest.py
+++ b/train_partseg_forest.py
@@ -10,12 +10,12 @@ import logging
 import sys
 import importlib
 import shutil
+
+import wandb
 
 import provider
 import numpy as np
-from pathlib import Path
 from tqdm import tqdm
-from dataset import PartNormalDataset
 from nibio_transformer_semantic.dataset import Dataset
 import hydra
 import omegaconf
@@ -44,17 +44,14 @@ def to_categorical(y, num_classes):
 
 def main(args):
     omegaconf.OmegaConf.set_struct(args, False)
+    conf = omegaconf.OmegaConf.to_container(args, resolve=True)
+
+    wandb.init(project="forest-point-transformer", entity="maciej-wielgosz-nibio", config=conf)
+
     '''HYPER PARAMETER'''
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
-
-    # print('GPU available: {}'.format(torch.cuda.is_available()))
 
     logger = logging.getLogger(__name__)
-    # print(args.pretty())
-
-    # use pretty print to print the config
-
-
 
     train_dataset = hydra.utils.to_absolute_path('data/forest_txt/train_txt/')
     test_dataset = hydra.utils.to_absolute_path('data/forest_txt/test_txt/')
@@ -74,7 +71,6 @@ def main(args):
     '''MODEL LOADING'''
     args.input_dim = (6 if args.normal else 3) + 16
     args.num_class = 4
-    num_category = 1
     num_part = args.num_class
 
     shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
@@ -143,20 +139,8 @@ def main(args):
 
             points, label = points.float().cuda(), label.long().cuda()
             optimizer.zero_grad()
-
-            # print("points shape ..: ", points.shape)
-            # print("label shape ...: ", label.shape)
-            # print(" points.shape[1] : ", points.shape[1])
-            # print("to_categorical(label, num_category): ", to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).shape)
-
-            # print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape)
-            # print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape)
-
             seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
-            # seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1))
-
-
-
+
             seg_pred = seg_pred.contiguous().view(-1, num_part)
             target = label.view(-1, 1)[:, 0]
             pred_choice = seg_pred.data.max(1)[1]
@@ -164,11 +148,18 @@ def main(args):
             correct = pred_choice.eq(target.data).cpu().sum()
             mean_correct.append(correct.item() / (args.batch_size * args.num_point))
             loss = criterion(seg_pred, target)
+            # add loss to wandb
+            wandb.log({'train_loss': loss.item()})
             loss.backward()
             optimizer.step()
+
 
         train_instance_acc = np.mean(mean_correct)
         logger.info('Train accuracy is: %.5f' % train_instance_acc)
+        # add train accuracy to wandb
+        wandb.log({'train_accuracy': train_instance_acc})
+        # add epoch to wandb
+        wandb.log({'epoch': epoch})
 
         with torch.no_grad():
             test_metrics = {}
@@ -264,8 +255,15 @@ def main(args):
         logger.info('Best accuracy is: %.5f' % best_acc)
         logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
         logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
+        wandb.log({'test_accuracy': test_metrics['accuracy']})
+        wandb.log({'test_class_avg_iou': test_metrics['class_avg_iou']})
+        wandb.log({'test_inctance_avg_iou': test_metrics['inctance_avg_iou']})
+        wandb.log({'best_accuracy': best_acc})
+        wandb.log({'best_class_avg_iou': best_class_avg_iou})
+        wandb.log({'best_inctance_avg_iou': best_inctance_avg_iou})
         global_epoch += 1
+
 
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
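
For readers puzzling over the surviving `seg_pred = classifier(torch.cat([...]))` line: `args.input_dim` is `(6 if args.normal else 3) + 16`, so the model expects a 16-way one-hot category vector appended to each point's channels, and the patch feeds a constant category index for every cloud (the `torch.ones(...)` tensor). A minimal shape sketch, with a hypothetical `to_categorical` standing in for the script's helper, whose body this diff does not show:

```python
import torch

def to_categorical(y, num_classes):
    # Hypothetical stand-in for the script's helper: one-hot encode integer
    # labels y into a float tensor of shape (*y.shape, num_classes).
    return torch.eye(num_classes, device=y.device)[y.long()]

B, N = 4, 2048                                   # batch size, points per cloud
points = torch.randn(B, N, 6)                    # xyz + normals -> 6 channels
ones = torch.ones((B, 1), dtype=torch.float16)   # constant category index 1, as in the patch
one_hot = to_categorical(ones, 16)               # (B, 1, 16)
features = torch.cat([points, one_hot.repeat(1, N, 1)], -1)
print(features.shape)                            # torch.Size([4, 2048, 22]) == 6 + 16 == args.input_dim
```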
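One behavior of the logging calls worth knowing: by default each `wandb.log(...)` call commits its own history row and advances wandb's internal step counter, so the per-batch `train_loss` rows and the six per-epoch rows at the end of the epoch land on separate steps. A sketch of a variant that groups the epoch-level metrics into a single row, reusing the patch's metric names (including the existing 'inctance' spelling, which matches the code's dictionary keys):

```python
import wandb

def log_epoch_metrics(epoch, train_instance_acc, test_metrics,
                      best_acc, best_class_avg_iou, best_inctance_avg_iou):
    # One wandb.log call == one history row, so all epoch-level metrics stay
    # aligned; 'epoch' can then serve as the x-axis in the wandb UI.
    wandb.log({
        'epoch': epoch,
        'train_accuracy': train_instance_acc,
        'test_accuracy': test_metrics['accuracy'],
        'test_class_avg_iou': test_metrics['class_avg_iou'],
        'test_inctance_avg_iou': test_metrics['inctance_avg_iou'],
        'best_accuracy': best_acc,
        'best_class_avg_iou': best_class_avg_iou,
        'best_inctance_avg_iou': best_inctance_avg_iou,
    })
```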