Skip to content
Snippets Groups Projects
Commit b08eadc7 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update training for forest transformer

parent e7e6ceac
Branches
No related tags found
No related merge requests found
......@@ -10,12 +10,12 @@ import logging
import sys
import importlib
import shutil
import wandb
import provider
import numpy as np
from pathlib import Path
from tqdm import tqdm
from dataset import PartNormalDataset
from nibio_transformer_semantic.dataset import Dataset
import hydra
import omegaconf
......@@ -44,17 +44,14 @@ def to_categorical(y, num_classes):
def main(args):
omegaconf.OmegaConf.set_struct(args, False)
conf = omegaconf.OmegaConf.to_container(args, resolve=True)
wandb.init(project="forest-point-transformer", entity="maciej-wielgosz-nibio", config=conf)
'''HYPER PARAMETER'''
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
# print('GPU available: {}'.format(torch.cuda.is_available()))
logger = logging.getLogger(__name__)
# print(args.pretty())
# use pretty print to print the config
train_dataset = hydra.utils.to_absolute_path('data/forest_txt/train_txt/')
test_dataset = hydra.utils.to_absolute_path('data/forest_txt/test_txt/')
......@@ -74,7 +71,6 @@ def main(args):
'''MODEL LOADING'''
args.input_dim = (6 if args.normal else 3) + 16
args.num_class = 4
num_category = 1
num_part = args.num_class
shutil.copy(hydra.utils.to_absolute_path('models/{}/model.py'.format(args.model.name)), '.')
......@@ -143,20 +139,8 @@ def main(args):
points, label = points.float().cuda(), label.long().cuda()
optimizer.zero_grad()
# print("points shape ..: ", points.shape)
# print("label shape ...: ", label.shape)
# print(" points.shape[1] : ", points.shape[1])
# print("to_categorical(label, num_category): ", to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).shape)
# print("to_categorical(label, num_category).repeat(1, points.shape[1], 1): ", to_categorical(torch.tensor(1).cuda(), num_category).repeat(1, points.shape[1], 1).shape)
# print("input shape: ", torch.cat([points, to_categorical(torch.tensor(1).unsqueeze(dim=0).unsqueeze(dim=0).cuda(), num_category).repeat(1, points.shape[1], 1)], -1).shape)
seg_pred = classifier(torch.cat([points, to_categorical(torch.ones((points.shape[0], 1), dtype=torch.float16).cuda(), 16).repeat(1, points.shape[1], 1)], -1))
# seg_pred = classifier(torch.cat([points, to_categorical(label, num_category)], -1))
seg_pred = seg_pred.contiguous().view(-1, num_part)
target = label.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1]
......@@ -164,11 +148,18 @@ def main(args):
correct = pred_choice.eq(target.data).cpu().sum()
mean_correct.append(correct.item() / (args.batch_size * args.num_point))
loss = criterion(seg_pred, target)
# add loss to wandb
wandb.log({'train_loss': loss.item()})
loss.backward()
optimizer.step()
train_instance_acc = np.mean(mean_correct)
logger.info('Train accuracy is: %.5f' % train_instance_acc)
# add train accuracy to wandb
wandb.log({'train_accuracy': train_instance_acc})
# add epoch to wandb
wandb.log({'epoch': epoch})
with torch.no_grad():
test_metrics = {}
......@@ -264,8 +255,15 @@ def main(args):
logger.info('Best accuracy is: %.5f' % best_acc)
logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
logger.info('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
wandb.log({'test_accuracy': test_metrics['accuracy']})
wandb.log({'test_class_avg_iou': test_metrics['class_avg_iou']})
wandb.log({'test_inctance_avg_iou': test_metrics['inctance_avg_iou']})
wandb.log({'best_accuracy': best_acc})
wandb.log({'best_class_avg_iou': best_class_avg_iou})
wandb.log({'best_inctance_avg_iou': best_inctance_avg_iou})
global_epoch += 1
if __name__ == '__main__':
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment