Commit 821d7064 authored by Maciej Wielgosz

metrics in dgcnn implemented

parent d04d047b
# create a config file
training:
    max_epochs: 2  # changed from 4 in this commit
    lr: 0.0001
batch_size: 4
shuffle: True
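The trainer setup later in the diff also reads config['training']['devices'] and config['wandb']['use_wandb'], which are not part of this snippet. A fuller config.yaml might therefore look roughly like the sketch below; the devices and wandb values are illustrative assumptions, not taken from the commit.
# sketch of a fuller config.yaml, assuming the keys the training script reads later
training:
    max_epochs: 2
    lr: 0.0001
    batch_size: 4
    shuffle: True
    devices: 1            # assumed; any device spec accepted by pl.Trainer works
wandb:
    use_wandb: False      # assumed; True switches to the WandbLogger branch below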
@@ -5,9 +5,26 @@ from shapenet_data_dgcnn import ShapenetDataDgcnn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from model import DGCNN
from torchmetrics import Accuracy, Precision, Recall
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.strategies import DDPStrategy
class MetricsPrinterCallback(Callback):
def on_validation_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
print(f"Epoch {trainer.current_epoch}:")
print(f" val_accuracy: {metrics['val_acc']:.4f}")
print(f" val_precision: {metrics['val_precision']:.4f}")
print(f" val_recall: {metrics['val_recall']:.4f}")
def on_test_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
print(f"Epoch {trainer.current_epoch}:")
print(f" test_accuracy: {metrics['test_acc']:.4f}")
print(f" test_precision: {metrics['test_precision']:.4f}")
print(f" test_recall: {metrics['test_recall']:.4f}")
with open('config.yaml', 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
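trainer.callback_metrics is a plain dict holding the latest value logged for each key passed to self.log, so the MetricsPrinterCallback above only finds val_acc, val_precision and val_recall once validation_step has logged them. A slightly more defensive variant, sketched here as an assumption rather than part of the commit, skips any key that has not been logged yet:
# hypothetical variant of the callback above; Callback is already imported from
# pytorch_lightning.callbacks at the top of this hunk
class SafeMetricsPrinterCallback(Callback):
    def on_validation_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        print(f"Epoch {trainer.current_epoch}:")
        for key in ("val_acc", "val_precision", "val_recall"):
            value = metrics.get(key)          # None if the key has not been logged yet
            if value is not None:
                print(f"  {key}: {value:.4f}")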
@@ -16,6 +33,19 @@ class DGCNNLightning(pl.LightningModule):
def __init__(self, num_classes):
super().__init__()
self.dgcnn = DGCNN(num_classes)
# train define metrics
self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
self.train_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
# val define metrics
self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
self.val_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
self.val_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
# test define metrics
self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
self.test_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
self.test_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
def forward(self, x):
return self.dgcnn(x)
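Each torchmetrics object above keeps its own accumulated state, which is why separate train/val/test instances are created: calling a metric on a batch updates that state and returns the batch-level value, compute() aggregates everything seen since the last reset, and reset() clears it for the next epoch. A minimal standalone sketch of that lifecycle, using synthetic tensors unrelated to the ShapeNet data:
# demo of the torchmetrics update / compute / reset cycle used throughout this module
import torch
from torchmetrics import Accuracy

acc = Accuracy(task='multiclass', num_classes=3)
preds = torch.tensor([[0.8, 0.1, 0.1],    # argmax -> class 0
                      [0.2, 0.7, 0.1]])   # argmax -> class 1
target = torch.tensor([0, 2])

batch_acc = acc(preds, target)    # updates internal state, returns this batch's accuracy (0.5)
epoch_acc = acc.compute()         # aggregates over all batches since the last reset
acc.reset()                       # clear state before the next epoch
print(batch_acc, epoch_acc)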
@@ -23,24 +53,70 @@ class DGCNNLightning(pl.LightningModule):
def training_step(self, batch, batch_idx):
points, _, class_name = batch
pred = self(points)
pred = torch.softmax(pred, dim=1)
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# metrics
self.log('train_loss', loss, sync_dist=True)
self.log('train_acc', self.train_accuracy(pred, class_name),sync_dist=True)
self.log('train_precision', self.train_class_precision(pred, class_name), sync_dist=True)
self.log('train_recall', self.train_recall(pred, class_name), sync_dist=True)
return loss
def training_epoch_end(self, outputs):
# logs epoch metrics
        self.log('train_accuracy_epoch', self.train_accuracy.compute(), sync_dist=True)
self.log('train_precision_epoch', self.train_class_precision.compute(), sync_dist=True)
self.log('train_recall_epoch', self.train_recall.compute(), sync_dist=True)
# reset metrics
self.train_accuracy.reset()
self.train_class_precision.reset()
self.train_recall.reset()
def validation_step(self, batch, batch_idx):
points, _, class_name = batch
pred = self(points)
pred = torch.softmax(pred, dim=1)
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# update metrics
self.log('val_loss', loss, sync_dist=True)
self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True)
self.log('val_precision', self.val_class_precision(pred, class_name), sync_dist=True)
self.log('val_recall', self.val_recall(pred, class_name), sync_dist=True)
return loss
def validation_epoch_end(self, outputs):
# logs epoch metrics
self.log('val_acc_epoch', self.val_accuracy.compute(), sync_dist=True)
self.log('val_precision_epoch', self.val_class_precision.compute(), sync_dist=True)
self.log('val_recall_epoch', self.val_recall.compute(), sync_dist=True)
# reset metrics
self.val_accuracy.reset()
self.val_class_precision.reset()
self.val_recall.reset()
def test_step(self, batch, batch_idx):
points, _, class_name = batch
pred = self(points)
pred = torch.softmax(pred, dim=1)
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# update metrics
self.log('test_loss', loss)
self.log('test_acc', self.test_accuracy(pred, class_name))
self.log('test_precision', self.test_class_precision(pred, class_name))
self.log('test_recall', self.test_recall(pred, class_name))
return loss
def test_epoch_end(self, outputs):
# logs epoch metrics
self.log('test_acc_epoch', self.test_accuracy.compute())
self.log('test_precision_epoch', self.test_class_precision.compute())
self.log('test_recall_epoch', self.test_recall.compute())
# reset metrics
self.test_accuracy.reset()
self.test_class_precision.reset()
self.test_recall.reset()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
return optimizer
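The *_epoch_end hooks above call compute() and reset() by hand. The torchmetrics Lightning integration also documents a pattern where the metric object itself is passed to self.log with on_epoch=True, in which case Lightning performs the epoch-level compute, reset, and cross-process sync automatically. A drop-in sketch for the validation metrics, assuming a torchmetrics release with that integration; this is an alternative, not what the commit does:
    # hypothetical alternative to validation_step + validation_epoch_end above
    def validation_step(self, batch, batch_idx):
        points, _, class_name = batch
        pred = torch.softmax(self(points), dim=1)
        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
        self.val_accuracy(pred, class_name)          # update state only
        self.log('val_loss', loss, sync_dist=True)
        # logging the Metric object lets Lightning call compute()/reset() at epoch end;
        # torchmetrics handles the DDP state sync internally
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True)
        return loss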
@@ -120,11 +196,13 @@ if config['wandb']['use_wandb']:
logger=wandb_logger
)
else:
metrics_printer_callback = MetricsPrinterCallback()
trainer = pl.Trainer(
strategy=DDPStrategy(find_unused_parameters=True),
accelerator="auto",
devices=config['training']['devices'],
max_epochs=config['training']['max_epochs'],
callbacks=[metrics_printer_callback]
)
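The if branch above passes logger=wandb_logger, whose construction falls outside this diff. Given the WandbLogger import at the top of the file, it is presumably created roughly like this (the project name is an assumption, and the printer callback could be attached in that branch as well):
# assumed construction of the logger used in the wandb branch; the real project /
# entity names are not visible in this diff
wandb_logger = WandbLogger(project='dgcnn-shapenet')
# e.g. pl.Trainer(..., logger=wandb_logger, callbacks=[metrics_printer_callback])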
# Initialize a model
@@ -140,3 +218,5 @@ trainer.fit(model, dataloader_train, dataloader_val)
trainer.test(model, dataloader_val)
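trainer.test also returns the metrics logged during testing (one dict per test dataloader), so the epoch-level values from test_epoch_end can be captured directly; a small usage sketch equivalent to the call above, with the dict keys following the names logged earlier:
# trainer.test returns a list with one dict of logged metrics per test dataloader
results = trainer.test(model, dataloader_val)
print(results[0]['test_acc_epoch'], results[0]['test_precision_epoch'])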