Commit 024f1f32 authored by Maciej Wielgosz

updated model of DGCNN and removed softmax

parent 2a6c22bb
@@ -23,8 +23,8 @@ def train():
         root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
         npoints=256,
         return_cls_label=True,
-        small_data=False,
-        small_data_size=1000,
+        small_data=True,
+        small_data_size=300,
         just_one_class=False,
         split='train',
         norm=True
......
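A note on the first hunk: the ShapenetDataDgcnn internals are outside this diff, so the exact behavior of small_data is an assumption, but flags like this conventionally truncate the sample index to give a fast smoke-test run, roughly like this sketch:

# Hedged sketch: the dataset class is not shown in this diff, so this only
# illustrates what a small_data switch usually does.
samples = list(range(10_000))            # stand-in for the full sample index
small_data, small_data_size = True, 300
if small_data:
    samples = samples[:small_data_size]  # keep only the first 300 samples
print(len(samples))  # 300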
@@ -8,12 +8,21 @@ from model import DGCNN
 from torchmetrics import Accuracy, Precision, Recall
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.strategies import DDPStrategy
+from torch.optim.lr_scheduler import CosineAnnealingLR

 class MetricsPrinterCallback(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
+        print(f"Epoch {trainer.current_epoch}:")
+        print(f"  train_loss: {metrics['train_loss']:.4f}")
+        print(f"  train_accuracy: {metrics['train_accurcy_epoch']:.4f}")
+        print(f"  train_precision: {metrics['train_precision_epoch']:.4f}")
+        print(f"  train_recall: {metrics['train_recall_epoch']:.4f}")
+
     def on_validation_end(self, trainer, pl_module):
         metrics = trainer.callback_metrics
         print(f"Epoch {trainer.current_epoch}:")
         print(f"  val_loss: {metrics['val_loss']:.4f}")
         print(f"  val_accuracy: {metrics['val_acc']:.4f}")
         print(f"  val_precision: {metrics['val_precision']:.4f}")
         print(f"  val_recall: {metrics['val_recall']:.4f}")
@@ -53,8 +62,7 @@ class DGCNNLightning(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # metrics
         self.log('train_loss', loss, sync_dist=True)
         self.log('train_acc', self.train_accuracy(pred, class_name), sync_dist=True)
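Removing torch.softmax before the loss is the substance of this commit: F.cross_entropy applies log_softmax to its input internally, so feeding it probabilities instead of raw logits computes a "double softmax" loss with compressed values and weak gradients. Dropping ignore_index=255 is consistent, since that option only matters when 255 is a designated ignore label for the targets. The same fix repeats in validation_step and test_step below. A small self-contained check:

import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, -2.0, 0.5]])
target = torch.tensor([0])

loss_logits = F.cross_entropy(logits, target)                        # correct usage
loss_double = F.cross_entropy(torch.softmax(logits, dim=1), target)  # pre-softmaxed

print(loss_logits.item())  # ~0.03: confident prediction, small loss
print(loss_double.item())  # ~0.57: squashed toward uniform, weak gradient signal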
@@ -76,8 +84,7 @@ class DGCNNLightning(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('val_loss', loss, sync_dist=True)
         self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True)
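The sync_dist=True on these self.log calls matters because the script imports DDPStrategy: under multi-GPU DDP each rank would otherwise log only its local value, while sync_dist all-reduces (mean by default) across ranks before the value reaches loggers and callbacks. A plain-Python illustration of that reduction, not Lightning code:

# Two ranks compute different local val_loss values on their data shards.
rank_losses = [0.90, 0.70]

# What sync_dist=True reports on every rank: the mean across ranks.
logged = sum(rank_losses) / len(rank_losses)
print(logged)  # 0.8, identical on all ranks, instead of a per-rank value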
@@ -98,28 +105,31 @@ class DGCNNLightning(pl.LightningModule):
     def test_step(self, batch, batch_idx):
         points, _, class_name = batch
         pred = self(points)
-        pred = torch.softmax(pred, dim=1)
-        loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
+        loss = F.cross_entropy(pred, class_name, reduction='mean')
         # update metrics
         self.log('test_loss', loss)
-        self.log('test_acc', self.test_accuracy(pred, class_name))
-        self.log('test_precision', self.test_class_precision(pred, class_name))
-        self.log('test_recall', self.test_recall(pred, class_name))
+        self.log('test_acc', self.test_accuracy(pred, class_name), sync_dist=True)
+        self.log('test_precision', self.test_class_precision(pred, class_name), sync_dist=True)
+        self.log('test_recall', self.test_recall(pred, class_name), sync_dist=True)
         return loss

     def test_epoch_end(self, outputs):
         # logs epoch metrics
-        self.log('test_acc_epoch', self.test_accuracy.compute())
-        self.log('test_precision_epoch', self.test_class_precision.compute())
-        self.log('test_recall_epoch', self.test_recall.compute())
+        self.log('test_acc_epoch', self.test_accuracy.compute(), sync_dist=True)
+        self.log('test_precision_epoch', self.test_class_precision.compute(), sync_dist=True)
+        self.log('test_recall_epoch', self.test_recall.compute(), sync_dist=True)
         # reset metrics
         self.test_accuracy.reset()
         self.test_class_precision.reset()
         self.test_recall.reset()

     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
-        return optimizer
+        # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
+        # return optimizer
+        optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
+        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}

 # get train data
 shapenet_data_train = ShapenetDataDgcnn(
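The new configure_optimizers swaps Adam for SGD with momentum plus cosine annealing, keeping the old Adam lines as comments. CosineAnnealingLR steps on a fixed schedule, lr(t) = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2, so the "monitor" key in the returned dict is inert here: monitoring only affects metric-driven schedulers such as ReduceLROnPlateau. A runnable sketch of the decay with the values above:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)

for epoch in range(100):   # Lightning steps the scheduler once per epoch by default
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # ~0.001: decayed from 0.1 over T_max epochs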
@@ -161,7 +171,7 @@ shapenet_data_test = ShapenetDataDgcnn(
 dataloader_train = torch.utils.data.DataLoader(
     shapenet_data_train,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=True,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
@@ -170,7 +180,7 @@ dataloader_train = torch.utils.data.DataLoader(
 dataloader_val = torch.utils.data.DataLoader(
     shapenet_data_val,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=False,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
@@ -179,7 +189,7 @@ dataloader_val = torch.utils.data.DataLoader(
 dataloader_test = torch.utils.data.DataLoader(
     shapenet_data_test,
     batch_size=config['training']['batch_size'],
-    shuffle=config['training']['shuffle'],
+    shuffle=False,
     num_workers=config['training']['num_workers'],
     drop_last=True
 )
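The hardcoded shuffle flags now follow the standard convention: randomize only the training split and keep a deterministic order for validation and test so metrics are comparable across epochs. One caveat worth flagging is that drop_last=True is kept on all three loaders, which silently discards up to batch_size - 1 samples from each evaluation split. A minimal sketch with a toy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(10, 3), torch.zeros(10, dtype=torch.long))

train_loader = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
eval_loader = DataLoader(ds, batch_size=4, shuffle=False, drop_last=True)

# 10 samples / batch_size 4 -> 2 full batches; the trailing 2 samples
# are dropped, so evaluation metrics never see them.
print(len(train_loader), len(eval_loader))  # 2 2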
......
@@ -141,10 +141,7 @@ class DGCNN(nn.Module):
         x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1)  # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
         # print("x7 shape: ", x7.shape)
         x8 = torch.cat((x6, x7), 1)  # (batch_size, emb_dims*2)
-        x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2)  # (batch_size, emb_dims*2) -> (batch_size, 512)
-        x9 = torch.max(x4, dim=1, keepdim=True)[0]
-        x10 = self.fc(x9.squeeze(1))
-        return x10
\ No newline at end of file
+        # x9 = x9.max(dim=1, keepdim=False)[0]
+        x10 = self.linear1(x8)
+        x11 = self.fc(x10)
+        return x11
\ No newline at end of file
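The rewritten head keeps the concatenated global descriptor (x6 and x7) and feeds it straight through linear1 and fc, dropping the bn6 + leaky_relu step and the stray second max over x4. Note that with no nonlinearity between them, linear1 and fc now compose to a single affine map. A shape-level sketch; emb_dims, the pooling behind x6, and the fc output width are assumptions, since those lines fall outside this hunk:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, emb_dims, num_points, num_classes = 4, 1024, 256, 16  # assumed sizes
x_conv = torch.randn(batch_size, emb_dims, num_points)

x6 = F.adaptive_max_pool1d(x_conv, 1).view(batch_size, -1)  # (B, emb_dims), assumed max-pool branch
x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1)  # (B, emb_dims)
x8 = torch.cat((x6, x7), 1)                                 # (B, emb_dims*2)

linear1 = nn.Linear(emb_dims * 2, 512)  # matches the (emb_dims*2) -> 512 comment
fc = nn.Linear(512, num_classes)        # hypothetical width for the final layer

logits = fc(linear1(x8))  # raw logits: softmax now lives inside F.cross_entropy
print(logits.shape)       # torch.Size([4, 16])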