diff --git a/dgcnn/dgcnn_train.py b/dgcnn/dgcnn_train.py
index 59f502ea51d0107b5216bdc6d05f7498f03bcefb..132972b8a3d7ed92815240cbea3bbd631fea209d 100644
--- a/dgcnn/dgcnn_train.py
+++ b/dgcnn/dgcnn_train.py
@@ -23,8 +23,8 @@ def train():
         root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
         npoints=256,
         return_cls_label=True,
-        small_data=False,
-        small_data_size=1000,
+        small_data=True,
+        small_data_size=300,
         just_one_class=False,
         split='train',
         norm=True
diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py
index 6a4bae7b55343cafec5613ed3e37a9e5b741e678..67b04b9dfa781cdc97d87e50528183f47040a169 100644
--- a/dgcnn/dgcnn_train_pl.py
+++ b/dgcnn/dgcnn_train_pl.py
@@ -8,12 +8,21 @@ from model import DGCNN
 from torchmetrics import Accuracy, Precision, Recall
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
-
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 class MetricsPrinterCallback(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f" train_loss: {metrics['train_loss']:.4f}")
+        print(f" train_accuracy: {metrics['train_acc_epoch']:.4f}")
+        print(f" train_precision: {metrics['train_precision_epoch']:.4f}")
+        print(f" train_recall: {metrics['train_recall_epoch']:.4f}")
+
     def on_validation_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         print(f"Epoch {trainer.current_epoch}:")
+        print(f" val_loss: {metrics['val_loss']:.4f}")
         print(f" val_accuracy: {metrics['val_acc']:.4f}")
         print(f" val_precision: {metrics['val_precision']:.4f}")
         print(f" val_recall: {metrics['val_recall']:.4f}")
@@ -53,8 +62,7 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # metrics
         self.log('train_loss', loss, sync_dist=True)
         self.log('train_acc', self.train_accuracy(pred, class_name), sync_dist=True)
@@ -76,8 +84,7 @@ class DGCNNLightning(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('val_loss', loss, sync_dist=True)
         self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True)
@@ -98,28 +105,31 @@ class DGCNNLightning(pl.LightningModule):
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('test_loss', loss)
-        self.log('test_acc', self.test_accuracy(pred, class_name))
-        self.log('test_precision', self.test_class_precision(pred, class_name))
-        self.log('test_recall', self.test_recall(pred, class_name))
+        self.log('test_acc', self.test_accuracy(pred, class_name), sync_dist=True)
+        self.log('test_precision', self.test_class_precision(pred, class_name), sync_dist=True)
+        self.log('test_recall', self.test_recall(pred, class_name), sync_dist=True)
         return loss
 
     def test_epoch_end(self, outputs):
         # logs epoch metrics
-        self.log('test_acc_epoch', self.test_accuracy.compute())
-        self.log('test_precision_epoch', self.test_class_precision.compute())
-        self.log('test_recall_epoch', self.test_recall.compute())
+        self.log('test_acc_epoch', self.test_accuracy.compute(), sync_dist=True)
+        self.log('test_precision_epoch', self.test_class_precision.compute(), sync_dist=True)
+        self.log('test_recall_epoch', self.test_recall.compute(), sync_dist=True)
         # reset metrics
         self.test_accuracy.reset()
         self.test_class_precision.reset()
         self.test_recall.reset()
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
-        return optimizer
+        # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
+        # return optimizer
+        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
+        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
+
 
 # get train data
 shapenet_data_train = ShapenetDataDgcnn(
@@ -161,7 +171,7 @@ shapenet_data_test = ShapenetDataDgcnn(
 dataloader_train = torch.utils.data.DataLoader(
     shapenet_data_train,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=True,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
@@ -170,7 +180,7 @@ dataloader_train = torch.utils.data.DataLoader(
 dataloader_val = torch.utils.data.DataLoader(
     shapenet_data_val,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=False,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
@@ -179,7 +189,7 @@ dataloader_val = torch.utils.data.DataLoader(
 dataloader_test = torch.utils.data.DataLoader(
     shapenet_data_test,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=False,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
diff --git a/dgcnn/model.py b/dgcnn/model.py
index f90b1e2fd2e36b42f8852af028030b5736ec9504..b3dae823428b84f5329fc676fb94deb93c8656b0 100644
--- a/dgcnn/model.py
+++ b/dgcnn/model.py
@@ -141,10 +141,7 @@ class DGCNN(nn.Module):
         x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1)  # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
         # print("x7 shape: ", x7.shape)
         x8 = torch.cat((x6, x7), 1)  # (batch_size, emb_dims*2)
-
-        x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2)  # (batch_size, emb_dims*2) -> (batch_size, 512)
-
-
-        x9 = torch.max(x4, dim=1, keepdim=True)[0]
-        x10 = self.fc(x9.squeeze(1))
-        return x10
\ No newline at end of file
+        # x9 = x9.max(dim=1, keepdim=False)[0]
+        x10 = self.linear1(x8)
+        x11 = self.fc(x10)
+        return x11
\ No newline at end of file
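
Note on the loss change: `F.cross_entropy` applies `log_softmax` internally, so the old code's explicit `torch.softmax` before the loss double-normalized the outputs and flattened the gradient signal; passing raw logits, as the patch now does, is the intended usage. A minimal sketch of the equivalence (shapes here are illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 4 samples, 16 classes (not the repo's actual dims).
logits = torch.randn(4, 16)
target = torch.randint(0, 16, (4,))

# F.cross_entropy == log_softmax + nll_loss, so it expects raw logits.
loss_direct = F.cross_entropy(logits, target, reduction='mean')
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target, reduction='mean')
assert torch.allclose(loss_direct, loss_manual)

# Softmaxing first (the old code path) computes a different, flatter loss.
loss_double = F.cross_entropy(torch.softmax(logits, dim=1), target)
print(loss_direct.item(), loss_double.item())  # the two values differ
```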
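Note on `configure_optimizers`: the top-level `"monitor"` key is only consulted by schedulers that track a metric (e.g. `ReduceLROnPlateau`); `CosineAnnealingLR` steps on a fixed schedule and ignores it. Lightning also accepts a nested `lr_scheduler` dict for explicit control over stepping. A sketch of that form, using a hypothetical stand-in module (`LitExample` is not from the repo, and the hyperparameters mirror the patch rather than any verified config):

```python
import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import CosineAnnealingLR

class LitExample(pl.LightningModule):
    """Minimal stand-in module; only configure_optimizers matters here."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)  # placeholder parameters

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                # T_max should match the planned epoch count (100 assumed here).
                "scheduler": CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001),
                "interval": "epoch",  # step the scheduler once per epoch
                "frequency": 1,
            },
        }
```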