Commit d04d047b authored by Maciej Wielgosz

update with val and test - dgcnn

parent 865078f8
 # create a config file
 training:
-  max_epochs: 2
+  max_epochs: 4
   lr: 0.0001
   batch_size: 4
   shuffle: True
@@ -19,9 +19,10 @@ data:
   norm: true
 wandb:
-  use_wandb: false
+  use_wandb: true
+  watch_model: false
   project: dgcnn
-  name: dgcnn
+  name: dgcnn-train-val-test
   entity: maciej-wielgosz-nibio
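The training script below reads these values through a `config` dict. A minimal sketch of how that dict is presumably loaded (the filename `config.yaml` is an assumption; the loading code is not shown in this commit):

import yaml

# Load the experiment configuration into the `config` dict used by the
# script below. The path 'config.yaml' is hypothetical.
with open('config.yaml') as f:
    config = yaml.safe_load(f)

assert config['training']['max_epochs'] == 4  # the value bumped in this commit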
@@ -27,11 +27,25 @@ class DGCNNLightning(pl.LightningModule):
         self.log('train_loss', loss)
         return loss
 
+    def validation_step(self, batch, batch_idx):
+        points, _, class_name = batch
+        pred = self(points)
+        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        self.log('val_loss', loss)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        points, _, class_name = batch
+        pred = self(points)
+        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        self.log('test_loss', loss)
+        return loss
+
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
         return optimizer
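Both new steps reuse the same loss call as training_step. For clarity, a shape check of what F.cross_entropy expects here, with dummy tensors rather than repo data: pred is a (batch, num_classes) tensor of raw logits, class_name is a (batch,) tensor of int64 labels, and any target equal to ignore_index is excluded from the mean:

import torch
import torch.nn.functional as F

pred = torch.randn(4, 16)                  # logits: batch of 4, 16 ShapeNet classes
class_name = torch.tensor([0, 3, 255, 7])  # int64 class labels; 255 is skipped
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# the mean is taken over the 3 non-ignored samples only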
-# get data
+# get train data
 shapenet_data_train = ShapenetDataDgcnn(
     root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
     npoints=config['data']['npoints'],
@@ -43,7 +57,31 @@ shapenet_data_train = ShapenetDataDgcnn(
     norm=config['data']['norm']
 )
-# create a dataloader
+# get val data
+shapenet_data_val = ShapenetDataDgcnn(
+    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+    npoints=config['data']['npoints'],
+    return_cls_label=True,
+    small_data=config['data']['small_data'],
+    small_data_size=config['data']['small_data_size'],
+    just_one_class=config['data']['just_one_class'],
+    split='val',
+    norm=config['data']['norm']
+)
+
+# get test data
+shapenet_data_test = ShapenetDataDgcnn(
+    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
+    npoints=config['data']['npoints'],
+    return_cls_label=True,
+    small_data=config['data']['small_data'],
+    small_data_size=config['data']['small_data_size'],
+    just_one_class=config['data']['just_one_class'],
+    split='test',
+    norm=config['data']['norm']
+)
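The three dataset constructors differ only in the split argument (the train call presumably passes split='train' in the lines elided by the hunk above). An optional refactor, shown here as a sketch and not part of the commit:

def make_split(split):
    # One ShapeNet dataset per split; all other arguments are shared.
    return ShapenetDataDgcnn(
        root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
        npoints=config['data']['npoints'],
        return_cls_label=True,
        small_data=config['data']['small_data'],
        small_data_size=config['data']['small_data_size'],
        just_one_class=config['data']['just_one_class'],
        split=split,
        norm=config['data']['norm'],
    )

shapenet_data_train, shapenet_data_val, shapenet_data_test = (
    make_split(s) for s in ('train', 'val', 'test')
)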
+# create train dataloader
 dataloader_train = torch.utils.data.DataLoader(
     shapenet_data_train,
     batch_size=config['training']['batch_size'],
@@ -52,6 +90,24 @@ dataloader_train = torch.utils.data.DataLoader(
     drop_last=True
 )
+
+# create val dataloader
+dataloader_val = torch.utils.data.DataLoader(
+    shapenet_data_val,
+    batch_size=config['training']['batch_size'],
+    shuffle=config['training']['shuffle'],
+    num_workers=config['training']['num_workers'],
+    drop_last=True
+)
+
+# create test dataloader
+dataloader_test = torch.utils.data.DataLoader(
+    shapenet_data_test,
+    batch_size=config['training']['batch_size'],
+    shuffle=config['training']['shuffle'],
+    num_workers=config['training']['num_workers'],
+    drop_last=True
+)
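Note that the val and test loaders inherit shuffle and drop_last=True from the training settings, so with small splits some samples may be skipped in every epoch. The usual pattern for evaluation loaders, shown as a sketch rather than what the commit does:

# Evaluation loaders typically keep every sample and a fixed order.
dataloader_val = torch.utils.data.DataLoader(
    shapenet_data_val,
    batch_size=config['training']['batch_size'],
    shuffle=False,      # order does not matter for metrics
    num_workers=config['training']['num_workers'],
    drop_last=False,    # score every val sample exactly once
)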
 # Initialize a trainer
 if config['wandb']['use_wandb']:
@@ -59,7 +115,7 @@ if config['wandb']['use_wandb']:
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True),
         accelerator="auto",
-        devices=config['wandb']['devices'],
+        devices=config['training']['devices'],
         max_epochs=config['training']['max_epochs'],
         logger=wandb_logger
     )
@@ -67,15 +123,20 @@ else:
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True),
         accelerator="auto",
-        devices=[0],
-        max_epochs=3
+        devices=config['training']['devices'],
+        max_epochs=config['training']['max_epochs']
     )
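The hunk above elides where wandb_logger is constructed. A minimal sketch, assuming the standard pytorch_lightning WandbLogger fed from the wandb section of the config:

from pytorch_lightning.loggers import WandbLogger

# Hypothetical reconstruction of the elided logger setup; extra keyword
# arguments such as entity are forwarded to wandb.init().
wandb_logger = WandbLogger(
    project=config['wandb']['project'],
    name=config['wandb']['name'],
    entity=config['wandb']['entity'],
)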
 # Initialize a model
 model = DGCNNLightning(num_classes=16)
 
 if config['wandb']['use_wandb']:
-    wandb_logger.watch(model)
+    if config['wandb']['watch_model']:
+        wandb_logger.watch(model)
 
-# Train the model on gpu
-trainer.fit(model, dataloader_train)
+# Train the model on gpu and validate every epoch
+trainer.fit(model, dataloader_train, dataloader_val)
+
+# Test the model on gpu
+trainer.test(model, dataloader_test)
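For reference, Trainer.test returns a list with one dict of logged metrics per dataloader, so the value logged in test_step can be read back afterwards (illustrative usage, not part of the commit):

results = trainer.test(model, dataloader_test)
print(results[0]['test_loss'])  # the metric logged via self.log in test_step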