diff --git a/dgcnn/config.yaml b/dgcnn/config.yaml
index e36f25c6163b8cdedf985c2137fac046e000eb3e..d616ecea4010f2b0c8837d556f0f0b66ded7aea3 100644
--- a/dgcnn/config.yaml
+++ b/dgcnn/config.yaml
@@ -1,6 +1,6 @@
 # create a config file
 training:
-  max_epochs: 4
+  max_epochs: 2
   lr : 0.0001
   batch_size: 4
   shuffle: True
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 82fe4c2c2ef2df08a9436f1b8391d8bad32b4621..3da7d3c04fb90b785dd40551e9fb93b1e64aae86 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -5,9 +5,26 @@
 from shapenet_data_dgcnn import ShapenetDataDgcnn
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from model import DGCNN
+from torchmetrics import Accuracy, Precision, Recall
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
 
+class MetricsPrinterCallback(Callback):
+    def on_validation_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  val_accuracy: {metrics['val_acc']:.4f}")
+        print(f"  val_precision: {metrics['val_precision']:.4f}")
+        print(f"  val_recall: {metrics['val_recall']:.4f}")
+
+    def on_test_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  test_accuracy: {metrics['test_acc']:.4f}")
+        print(f"  test_precision: {metrics['test_precision']:.4f}")
+        print(f"  test_recall: {metrics['test_recall']:.4f}")
+
 with open('config.yaml', 'r') as f:
     config = yaml.load(f, Loader=yaml.FullLoader)
 
@@ -16,6 +33,19 @@ class DGCNNLightning(pl.LightningModule):
     def __init__(self, num_classes):
         super().__init__()
         self.dgcnn = DGCNN(num_classes)
+        # train define metrics
+        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.train_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.train_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # val define metrics
+        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.val_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.val_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+        # test define metrics
+        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
+        self.test_class_precision = Precision(task='multiclass', num_classes=num_classes, average='macro')
+        self.test_recall = Recall(task='multiclass', num_classes=num_classes, average='macro')
+
     def forward(self, x):
         return self.dgcnn(x)
 
@@ -23,24 +53,70 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # NOTE(review): softmax is kept only for the torchmetrics updates below;
+        # the loss must see raw logits because F.cross_entropy applies
+        # log-softmax internally (softmaxing twice flattens gradients).
+        probs = torch.softmax(pred, dim=1)
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('train_loss', loss)
+        # metrics
+        self.log('train_loss', loss, sync_dist=True)
+        self.log('train_acc', self.train_accuracy(probs, class_name),sync_dist=True)
+        self.log('train_precision', self.train_class_precision(probs, class_name), sync_dist=True)
+        self.log('train_recall', self.train_recall(probs, class_name), sync_dist=True)
+
         return loss
 
+    def training_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('train_accuracy_epoch', self.train_accuracy.compute(), sync_dist=True)
+        self.log('train_precision_epoch', self.train_class_precision.compute(), sync_dist=True)
+        self.log('train_recall_epoch', self.train_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.train_accuracy.reset()
+        self.train_class_precision.reset()
+        self.train_recall.reset()
+
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # probabilities for metrics only; loss uses raw logits (see training_step)
+        probs = torch.softmax(pred, dim=1)
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
-        self.log('val_loss', loss)
+        # update metrics
+        self.log('val_loss', loss, sync_dist=True)
+        self.log('val_acc', self.val_accuracy(probs, class_name), sync_dist=True)
+        self.log('val_precision', self.val_class_precision(probs, class_name), sync_dist=True)
+        self.log('val_recall', self.val_recall(probs, class_name), sync_dist=True)
         return loss
 
+    def validation_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('val_acc_epoch', self.val_accuracy.compute(), sync_dist=True)
+        self.log('val_precision_epoch', self.val_class_precision.compute(), sync_dist=True)
+        self.log('val_recall_epoch', self.val_recall.compute(), sync_dist=True)
+        # reset metrics
+        self.val_accuracy.reset()
+        self.val_class_precision.reset()
+        self.val_recall.reset()
+
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
+        # probabilities for metrics only; loss uses raw logits (see training_step)
+        probs = torch.softmax(pred, dim=1)
         loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        # update metrics
         self.log('test_loss', loss)
+        self.log('test_acc', self.test_accuracy(probs, class_name))
+        self.log('test_precision', self.test_class_precision(probs, class_name))
+        self.log('test_recall', self.test_recall(probs, class_name))
         return loss
 
+    def test_epoch_end(self, outputs):
+        # logs epoch metrics
+        self.log('test_acc_epoch', self.test_accuracy.compute())
+        self.log('test_precision_epoch', self.test_class_precision.compute())
+        self.log('test_recall_epoch', self.test_recall.compute())
+        # reset metrics
+        self.test_accuracy.reset()
+        self.test_class_precision.reset()
+        self.test_recall.reset()
+
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
         return optimizer
@@ -120,11 +196,13 @@ if config['wandb']['use_wandb']:
         logger=wandb_logger
     )
 else:
+    metrics_printer_callback = MetricsPrinterCallback()
     trainer = pl.Trainer(
         strategy=DDPStrategy(find_unused_parameters=True),
         accelerator="auto",
         devices=config['training']['devices'],
-        max_epochs=config['training']['max_epochs']
+        max_epochs=config['training']['max_epochs'],
+        callbacks=[metrics_printer_callback]
     )
 
 # Initialize a model
@@ -140,3 +218,5 @@ trainer.fit(model, dataloader_train, dataloader_val)
 
 trainer.test(model, dataloader_val)
 
+
+