Skip to content
Snippets Groups Projects
Commit 024f1f32 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated model of DGCNN and removed softmax

parent 2a6c22bb
No related branches found
No related tags found
No related merge requests found
...@@ -23,8 +23,8 @@ def train(): ...@@ -23,8 +23,8 @@ def train():
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
npoints=256, npoints=256,
return_cls_label=True, return_cls_label=True,
small_data=False, small_data=True,
small_data_size=1000, small_data_size=300,
just_one_class=False, just_one_class=False,
split='train', split='train',
norm=True norm=True
......
...@@ -8,12 +8,21 @@ from model import DGCNN ...@@ -8,12 +8,21 @@ from model import DGCNN
from torchmetrics import Accuracy, Precision, Recall from torchmetrics import Accuracy, Precision, Recall
from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks import Callback
from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.strategies import DDPStrategy
from torch.optim.lr_scheduler import CosineAnnealingLR
class MetricsPrinterCallback(Callback): class MetricsPrinterCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
print(f"Epoch {trainer.current_epoch}:")
print(f" train_loss: {metrics['train_loss']:.4f}")
print(f" train_accuracy: {metrics['train_accurcy_epoch']:.4f}")
print(f" train_precision: {metrics['train_precision_epoch']:.4f}")
print(f" train_recall: {metrics['train_recall_epoch']:.4f}")
def on_validation_end(self, trainer, pl_module): def on_validation_end(self, trainer, pl_module):
metrics = trainer.callback_metrics metrics = trainer.callback_metrics
print(f"Epoch {trainer.current_epoch}:") print(f"Epoch {trainer.current_epoch}:")
print(f" val_loss: {metrics['val_loss']:.4f}")
print(f" val_accuracy: {metrics['val_acc']:.4f}") print(f" val_accuracy: {metrics['val_acc']:.4f}")
print(f" val_precision: {metrics['val_precision']:.4f}") print(f" val_precision: {metrics['val_precision']:.4f}")
print(f" val_recall: {metrics['val_recall']:.4f}") print(f" val_recall: {metrics['val_recall']:.4f}")
...@@ -53,8 +62,7 @@ class DGCNNLightning(pl.LightningModule): ...@@ -53,8 +62,7 @@ class DGCNNLightning(pl.LightningModule):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
points, _, class_name = batch points, _, class_name = batch
pred = self(points) pred = self(points)
pred = torch.softmax(pred, dim=1) loss = F.cross_entropy(pred, class_name, reduction='mean')
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# metrics # metrics
self.log('train_loss', loss, sync_dist=True) self.log('train_loss', loss, sync_dist=True)
self.log('train_acc', self.train_accuracy(pred, class_name),sync_dist=True) self.log('train_acc', self.train_accuracy(pred, class_name),sync_dist=True)
...@@ -76,8 +84,7 @@ class DGCNNLightning(pl.LightningModule): ...@@ -76,8 +84,7 @@ class DGCNNLightning(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
points, _, class_name = batch points, _, class_name = batch
pred = self(points) pred = self(points)
pred = torch.softmax(pred, dim=1) loss = F.cross_entropy(pred, class_name, reduction='mean')
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# update metrics # update metrics
self.log('val_loss', loss, sync_dist=True) self.log('val_loss', loss, sync_dist=True)
self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True) self.log('val_acc', self.val_accuracy(pred, class_name), sync_dist=True)
...@@ -98,28 +105,31 @@ class DGCNNLightning(pl.LightningModule): ...@@ -98,28 +105,31 @@ class DGCNNLightning(pl.LightningModule):
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
points, _, class_name = batch points, _, class_name = batch
pred = self(points) pred = self(points)
pred = torch.softmax(pred, dim=1) loss = F.cross_entropy(pred, class_name, reduction='mean')
loss = F.cross_entropy(pred, class_name, reduction='mean', ignore_index=255)
# update metrics # update metrics
self.log('test_loss', loss) self.log('test_loss', loss)
self.log('test_acc', self.test_accuracy(pred, class_name)) self.log('test_acc', self.test_accuracy(pred, class_name), sync_dist=True)
self.log('test_precision', self.test_class_precision(pred, class_name)) self.log('test_precision', self.test_class_precision(pred, class_name), sync_dist=True)
self.log('test_recall', self.test_recall(pred, class_name)) self.log('test_recall', self.test_recall(pred, class_name), sync_dist=True)
return loss return loss
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
# logs epoch metrics # logs epoch metrics
self.log('test_acc_epoch', self.test_accuracy.compute()) self.log('test_acc_epoch', self.test_accuracy.compute(), sync_dist=True)
self.log('test_precision_epoch', self.test_class_precision.compute()) self.log('test_precision_epoch', self.test_class_precision.compute(), sync_dist=True)
self.log('test_recall_epoch', self.test_recall.compute()) self.log('test_recall_epoch', self.test_recall.compute(), sync_dist=True)
# reset metrics # reset metrics
self.test_accuracy.reset() self.test_accuracy.reset()
self.test_class_precision.reset() self.test_class_precision.reset()
self.test_recall.reset() self.test_recall.reset()
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr']) # optimizer = torch.optim.Adam(self.parameters(), lr=config['training']['lr'])
return optimizer # return optimizer
optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}
# get train data # get train data
shapenet_data_train = ShapenetDataDgcnn( shapenet_data_train = ShapenetDataDgcnn(
...@@ -161,7 +171,7 @@ shapenet_data_test = ShapenetDataDgcnn( ...@@ -161,7 +171,7 @@ shapenet_data_test = ShapenetDataDgcnn(
dataloader_train = torch.utils.data.DataLoader( dataloader_train = torch.utils.data.DataLoader(
shapenet_data_train, shapenet_data_train,
batch_size=config['training']['batch_size'], batch_size=config['training']['batch_size'],
shuffle=config['training']['shuffle'], shuffle=True,
num_workers=config['training']['num_workers'], num_workers=config['training']['num_workers'],
drop_last=True drop_last=True
) )
...@@ -170,7 +180,7 @@ dataloader_train = torch.utils.data.DataLoader( ...@@ -170,7 +180,7 @@ dataloader_train = torch.utils.data.DataLoader(
dataloader_val = torch.utils.data.DataLoader( dataloader_val = torch.utils.data.DataLoader(
shapenet_data_val, shapenet_data_val,
batch_size=config['training']['batch_size'], batch_size=config['training']['batch_size'],
shuffle=config['training']['shuffle'], shuffle=False,
num_workers=config['training']['num_workers'], num_workers=config['training']['num_workers'],
drop_last=True drop_last=True
) )
...@@ -179,7 +189,7 @@ dataloader_val = torch.utils.data.DataLoader( ...@@ -179,7 +189,7 @@ dataloader_val = torch.utils.data.DataLoader(
dataloader_test = torch.utils.data.DataLoader( dataloader_test = torch.utils.data.DataLoader(
shapenet_data_test, shapenet_data_test,
batch_size=config['training']['batch_size'], batch_size=config['training']['batch_size'],
shuffle=config['training']['shuffle'], shuffle=False,
num_workers=config['training']['num_workers'], num_workers=config['training']['num_workers'],
drop_last=True drop_last=True
) )
......
...@@ -141,10 +141,7 @@ class DGCNN(nn.Module): ...@@ -141,10 +141,7 @@ class DGCNN(nn.Module):
x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
# print("x7 shape: ", x7.shape) # print("x7 shape: ", x7.shape)
x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2) x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2)
# x9 = x9.max(dim=1, keepdim=False)[0]
x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512) x10 = self.linear1(x8)
x11 = self.fc(x10)
return x11
x9 = torch.max(x4, dim=1, keepdim=True)[0] \ No newline at end of file
x10 = self.fc(x9.squeeze(1))
return x10
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment