From 821d70646ced8dd883294355499ac253f0c144d5 Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Tue, 28 Mar 2023 11:40:13 +0200
Subject: [PATCH] metrics in dgcnn implemented

Track train/val/test accuracy, precision and recall with torchmetrics,
log them per step and per epoch, and print the epoch-level validation
and test metrics from a dedicated callback. The losses keep receiving
raw logits, since F.cross_entropy applies log-softmax internally, and
the torchmetrics multiclass metrics likewise accept raw logits.
---
 dgcnn/config.yaml       |  2 +-
 dgcnn/dgcnn_train_pl.py | 86 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml
index e36f25c..d616ece 100644
--- a/dgcnn/config.yaml
+++ b/dgcnn/config.yaml
@@ -1,6 +1,6 @@
 # create a config file
 training:
-  max_epochs: 4
+  max_epochs: 2
   lr : 0.0001
   batch_size: 4
   shuffle: True
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 82fe4c2..3da7d3c 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -5,9 +5,26 @@ from shapenet_data_dgcnn import ShapenetDataDgcnn
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from model import DGCNN
+from torchmetrics import Accuracy, Precision, Recall
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
 
+class MetricsPrinterCallback(Callback):
+    def on_validation_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  val_accuracy: {metrics['val_acc']:.4f}")
+        print(f"  val_precision: {metrics['val_precision']:.4f}")
+        print(f"  val_recall: {metrics['val_recall']:.4f}")
+
+    def on_test_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  test_accuracy: {metrics['test_acc']:.4f}")
+        print(f"  test_precision: {metrics['test_precision']:.4f}")
+        print(f"  test_recall: {metrics['test_recall']:.4f}")
+
 with open('config.yaml', 'r') as f:
     config = yaml.load(f, Loader=yaml.FullLoader)
 
 
@@ -16,6 +33,19 @@ class DGCNNLightning(pl.LightningModule):
     def __init__(self, num_classes):
         super().__init__()
         self.dgcnn = DGCNN(num_classes)
+        # train define metrics
+        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.train_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # val define metrics
+        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.val_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.val_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # test define metrics
+        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.test_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.test_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+
     def forward(self, x):
         return self.dgcnn(x)
 
@@ -23,24 +53,70 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # note: pred stays as raw logits; F.cross_entropy applies log-softmax itself
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('train_loss', loss)
+        # metrics (torchmetrics multiclass metrics accept raw logits)
+        self.log('train_loss', loss, sync_dist=True)
+        self.log('train_acc', self.train_accuracy(pred, class_name), sync_dist=True)
+        self.log('train_precision', self.train_class_precision(pred, class_name), sync_dist=True)
+        self.log('train_recall', self.train_recall(pred, class_name), sync_dist=True)
+
         return loss
 
+    def training_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('train_acc_epoch', self.train_accuracy.compute(), sync_dist=True)
+        self.log('train_precision_epoch', self.train_class_precision.compute(), sync_dist=True)
+        self.log('train_recall_epoch', self.train_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.train_accuracy.reset()
+        self.train_class_precision.reset()
+        self.train_recall.reset()
+
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # note: pred stays as raw logits; F.cross_entropy applies log-softmax itself
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('val_loss', loss)
+        # update metrics
+        self.log('val_loss', loss, sync_dist=True)
+        self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True)
+        self.log('val_precision', self.val_class_precision(pred, class_name), sync_dist=True)
+        self.log('val_recall', self.val_recall(pred, class_name), sync_dist=True)
         return loss
 
+    def validation_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('val_acc_epoch', self.val_accuracy.compute(), sync_dist=True)
+        self.log('val_precision_epoch', self.val_class_precision.compute(), sync_dist=True)
+        self.log('val_recall_epoch', self.val_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.val_accuracy.reset()
+        self.val_class_precision.reset()
+        self.val_recall.reset()
+
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # note: pred stays as raw logits; F.cross_entropy applies log-softmax itself
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        # update metrics
         self.log('test_loss', loss)
+        self.log('test_acc', self.test_accuracy(pred, class_name))
+        self.log('test_precision', self.test_class_precision(pred, class_name))
+        self.log('test_recall', self.test_recall(pred, class_name))
         return loss
 
+    def test_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('test_acc_epoch', self.test_accuracy.compute())
+        self.log('test_precision_epoch', self.test_class_precision.compute())
+        self.log('test_recall_epoch', self.test_recall.compute())
+        # reset metrics
+        self.test_accuracy.reset()
+        self.test_class_precision.reset()
+        self.test_recall.reset()
+
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
         return optimizer
@@ -120,11 +196,13 @@ if config['wandb']['use_wandb']:
         logger=wandb_logger
     )
 else:
+    metrics_printer_callback = MetricsPrinterCallback()
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True),
         accelerator="auto",
         devices=config['training']['devices'],
-        max_epochs=config['training']['max_epochs']
+        max_epochs=config['training']['max_epochs'],
+        callbacks=[metrics_printer_callback]
     )
 
 # Initialize a model
@@ -140,3 +218,5 @@ trainer.fit(model, dataloader_train, dataloader_val)
 
 trainer.test(model, dataloader_val)
 
+
+
-- 
GitLab