Skip to content
Snippets Groups Projects
Commit d04d047b authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

update with val and test - dgcnn

parent 865078f8
Branches
No related tags found
No related merge requests found
# create a config file
training:
max_epochs: 4
lr : 0.0001
batch_size: 4
shuffle: True
......@@ -19,9 +19,10 @@ data:
norm: true
wandb:
use_wandb: true
watch_model : false
project: dgcnn
name: dgcnn-train-val-test
entity: maciej-wielgosz-nibio
......@@ -27,11 +27,25 @@ class DGCNNLightning(pl.LightningModule):
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
    """Run one validation batch and log its loss.

    The batch is a 3-tuple ``(points, _, class_name)``; the middle
    element is unused here. Returns the mean cross-entropy loss,
    logged under ``'val_loss'``.
    """
    points, _, class_name = batch
    logits = self(points)
    # ignore_index=255 excludes targets marked 255 from the loss average
    val_loss = F.cross_entropy(logits, class_name, reduction='mean', ignore_index=255)
    self.log('val_loss', val_loss)
    return val_loss
def test_step(self, batch, batch_idx):
points, _, class_name = batch
pred = self(points)
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
self.log('test_loss', loss)
return loss
def configure_optimizers(self):
    """Build the optimizer Lightning will use for training.

    Returns an Adam optimizer over all model parameters, with the
    learning rate taken from the module-level ``config`` dict
    (``config['training']['lr']``).
    """
    learning_rate = config['training']['lr']
    return torch.optim.Adam(self.parameters(), lr=learning_rate)
# get data
# get train data
shapenet_data_train = ShapenetDataDgcnn(
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
npoints=config['data']['npoints'],
......@@ -43,7 +57,31 @@ shapenet_data_train = ShapenetDataDgcnn(
norm=config['data']['norm']
)
# create a dataloader
# get val data: same dataset options as training, but the 'val' split
_val_cfg = config['data']
shapenet_data_val = ShapenetDataDgcnn(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=_val_cfg['npoints'],
    return_cls_label=True,
    small_data=_val_cfg['small_data'],
    small_data_size=_val_cfg['small_data_size'],
    just_one_class=_val_cfg['just_one_class'],
    split='val',
    norm=_val_cfg['norm'],
)
# get test data: same dataset options as training, but the 'test' split
_test_cfg = config['data']
shapenet_data_test = ShapenetDataDgcnn(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=_test_cfg['npoints'],
    return_cls_label=True,
    small_data=_test_cfg['small_data'],
    small_data_size=_test_cfg['small_data_size'],
    just_one_class=_test_cfg['just_one_class'],
    split='test',
    norm=_test_cfg['norm'],
)
# create train dataloader
dataloader_train = torch.utils.data.DataLoader(
shapenet_data_train,
batch_size=config['training']['batch_size'],
......@@ -52,6 +90,24 @@ dataloader_train = torch.utils.data.DataLoader(
drop_last=True
)
# create val dataloader
# NOTE: evaluation data must be neither shuffled nor truncated —
# shuffle adds pointless nondeterminism to validation, and
# drop_last=True would silently skip the final partial batch,
# biasing the reported metrics.
dataloader_val = torch.utils.data.DataLoader(
    shapenet_data_val,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=config['training']['num_workers'],
    drop_last=False
)
# create test dataloader
# NOTE: evaluation data must be neither shuffled nor truncated —
# shuffle adds pointless nondeterminism to testing, and
# drop_last=True would silently skip the final partial batch,
# biasing the reported metrics.
dataloader_test = torch.utils.data.DataLoader(
    shapenet_data_test,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=config['training']['num_workers'],
    drop_last=False
)
# Initialize a trainer
if config['wandb']['use_wandb']:
......@@ -59,7 +115,7 @@ if config['wandb']['use_wandb']:
# Trainer used when wandb logging is enabled.
# The scrape carried both the old and the new 'devices=' line (a duplicate
# keyword argument is a SyntaxError); keep the corrected one — the device
# list belongs in the 'training' config section, not 'wandb'.
trainer = pl.Trainer(
    strategy=DDPStrategy(find_unused_parameters=True),
    accelerator="auto",
    devices=config['training']['devices'],
    max_epochs=config['training']['max_epochs'],
    logger=wandb_logger
)
......@@ -67,15 +123,20 @@ else:
# Trainer used when wandb logging is disabled (no logger attached).
# The scrape carried both the old hard-coded values (devices=[0],
# max_epochs=3) and the new config-driven ones; duplicate keyword
# arguments are a SyntaxError, so keep only the config-driven pair.
trainer = pl.Trainer(
    strategy=DDPStrategy(find_unused_parameters=True),
    accelerator="auto",
    devices=config['training']['devices'],
    max_epochs=config['training']['max_epochs']
)
# Initialize a model
model = DGCNNLightning(num_classes=16)

# wandb_logger is only created inside the use_wandb branch above, so guard
# on both flags — watching with use_wandb=false would raise NameError.
if config['wandb']['use_wandb'] and config['wandb']['watch_model']:
    wandb_logger.watch(model)

# Train the model on gpu and validate every epoch
trainer.fit(model, dataloader_train, dataloader_val)

# Test the model on gpu — evaluate the held-out test split
# (the original passed dataloader_val here, leaving dataloader_test unused).
trainer.test(model, dataloader_test)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment