Skip to content
Snippets Groups Projects
Commit f1468350 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated model for 1 - 4 categories shapenet segmentation

parent 84d50a5b
No related branches found
No related tags found
No related merge requests found
......@@ -5,14 +5,12 @@ import torch.nn.init as init
from my_models.model_shape_net import DgcnShapeNet as DGCNN
def main():
dgcnn = DGCNN(50)
dgcnn = DGCNN(50, 4)
dgcnn.eval()
# simple test
input_tensor = torch.randn(2, 128, 3)
label_one_hot = torch.zeros((2, 16))
label_one_hot = torch.zeros((2, 4))
print(input_tensor.shape)
out = dgcnn(input_tensor, label_one_hot)
print(out.shape)
......
......@@ -138,7 +138,7 @@ shapenet_data_train = ShapenetDataDgcnn(
return_cls_label=True,
small_data=config['data']['small_data'],
small_data_size=config['data']['small_data_size'],
num_classes=config['data']['just_four_classes'],
num_classes=config['data']['num_classes'],
split='train',
norm=config['data']['norm'],
augmnetation=config['data']['augmentation']
......@@ -151,7 +151,7 @@ shapenet_data_val = ShapenetDataDgcnn(
return_cls_label=True,
small_data=config['data']['small_data'],
small_data_size=config['data']['small_data_size'],
num_classes=config['data']['just_four_classes'],
num_classes=config['data']['num_classes'],
split='test',
norm=config['data']['norm']
)
......@@ -163,7 +163,7 @@ shapenet_data_test = ShapenetDataDgcnn(
return_cls_label=True,
small_data=config['data']['small_data'],
small_data_size=config['data']['small_data_size'],
num_classes=config['data']['just_four_classes'],
num_classes=config['data']['num_classes'],
split='test',
norm=config['data']['norm']
)
......@@ -217,10 +217,8 @@ else:
)
# Initialize a model
if config['data']['just_four_classes']:
num_classes = 4
else:
num_classes = 16
num_classes = int(config['data']['num_classes'])
model = DGCNNLightning(num_classes=num_classes)
if config['wandb']['use_wandb']:
......
......@@ -31,12 +31,12 @@ def train():
# get data
shapenet_data = ShapenetDataDgcnn(
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
npoints=1024,
npoints=32,
return_cls_label=True,
small_data=False,
small_data_size=300,
just_four_classes=False,
data_augmentation=True,
small_data=True,
small_data_size=10,
num_classes=True,
data_augmentation=False,
split='train',
norm=True
)
......@@ -44,7 +44,7 @@ def train():
# create a dataloader
dataloader = torch.utils.data.DataLoader(
shapenet_data,
batch_size=4,
batch_size=8,
shuffle=True,
num_workers=8,
drop_last=True
......@@ -77,6 +77,11 @@ def train():
# print(f"Batch: {i}")
points, labels, class_name = data
# log data to wandb
if use_wandb:
wandb.log({"points": wandb.Object3D(points[0, :, :].cpu().numpy())})
# wandb.log({"labels": labels.cpu().numpy()})
# wandb.log({"class_name": class_name.cpu().numpy()})
label_one_hot = np.zeros((class_name.shape[0], 16))
for idx in range(class_name.shape[0]):
label_one_hot[idx, class_name[idx]] = 1
......
import random
import numpy as np
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from matplotlib import pyplot as plt
from shapenet_data_dgcnn import ShapenetDataDgcnn
# import IoU from torchmetrics
from torchmetrics import Accuracy
from torchmetrics import JaccardIndex as IoU
from torch.optim.lr_scheduler import StepLR
import wandb
use_wandb = False
num_classes = 4
batch_size = 4
# create a wandb run
if use_wandb:
run = wandb.init(project="dgcnn", entity="maciej-wielgosz-nibio")
from my_models.model_shape_net import DgcnShapeNet as DGCNN
def train():
seg_num_all = 50
dgcnn = DGCNN(
seg_num_all=seg_num_all,
num_classes=num_classes
).cuda()
dgcnn.train()
# get data
shapenet_data = ShapenetDataDgcnn(
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
npoints=32,
return_cls_label=True,
small_data=True,
small_data_size=10,
num_classes=num_classes,
data_augmentation=False,
split='train',
norm=True
)
# create a dataloader
dataloader = torch.utils.data.DataLoader(
shapenet_data,
batch_size=batch_size,
shuffle=True,
num_workers=8,
drop_last=True
)
# create a optimizer
optimizer = torch.optim.Adam(dgcnn.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
if use_wandb:
# create a config wandb
wandb.config.update({
"batch_size": batch_size,
"learning_rate": 0.01,
"optimizer": "Adam",
"loss_function": "cross_entropy"
})
# train
iou = IoU(num_classes=50, task='multiclass', average='macro').cuda()
acc = Accuracy(num_classes=50, compute_on_step=False, dist_sync_on_step=False, task='multiclass').cuda()
for epoch in range(500):
iou.reset()
print(f"Epoch: {epoch}")
if use_wandb:
wandb.log({"epoch": epoch})
for i, data in enumerate(dataloader, 0):
# print(f"Batch: {i}")
points, labels, class_name = data
print('class_name', class_name)
label_one_hot = np.zeros((class_name.shape[0], num_classes))
for idx in range(class_name.shape[0]):
label_one_hot[idx, class_name[idx]] = 1
label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
points = points.cuda()
labels = labels.cuda()
label_one_hot = label_one_hot.cuda()
optimizer.zero_grad()
pred = dgcnn(points, label_one_hot)
loss = F.cross_entropy(
pred, labels, reduction='mean', ignore_index=255)
loss.backward()
optimizer.step()
if optimizer.param_groups[0]['lr'] > 1e-5:
scheduler.step()
# print lose every 10 batches
if i % 100 == 0:
print(loss.item())
pred_softmax = F.softmax(pred, dim=1)
pred_argmax = torch.argmax(pred_softmax, dim=1)
gt_one_hot = F.one_hot(labels, num_classes=50)
pred_one_hot = F.one_hot(pred_argmax, num_classes=50)
print("loss : ", loss.item())
print("IoU : ", iou(pred_one_hot, gt_one_hot))
print("Acc : ", acc(pred_argmax, labels))
if use_wandb:
wandb.log({"loss": loss.item()})
wandb.log({"iou": iou(pred, labels)})
wandb.log({"acc": acc(pred, labels)})
if __name__ == '__main__':
train()
\ No newline at end of file
......@@ -43,7 +43,7 @@ class DgcnShapeNet(nn.Module):
def forward(self, x, class_label):
def forward(self, x, class_label_one_hot):
# Apply Transform_Net on input point cloud
trans_matrix = self.transform_net(x)
......@@ -63,10 +63,10 @@ class DgcnShapeNet(nn.Module):
x = self.conv5(x5) # (batch_size, 1024, num_points)
x = x.max(dim=-1, keepdim=True)[0] # (batch_size, 1024)
class_label = class_label.view(batch_size, -1, 1) # (batch_size, num_categoties, 1)
class_label = self.conv7(class_label) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1)
class_label_one_hot = class_label_one_hot.view(batch_size, -1, 1) # (batch_size, num_categoties, 1)
class_label_one_hot = self.conv7(class_label_one_hot) # (batch_size, num_categoties, 1) -> (batch_size, 64, 1)
x = torch.cat((x, class_label), dim=1) # (batch_size, 1088, 1)
x = torch.cat((x, class_label_one_hot), dim=1) # (batch_size, 1088, 1)
x = x.repeat(1, 1, num_points) # (batch_size, 1088, num_points)
x = torch.cat((x, x1, x2), dim=1) # (batch_size, 1088+64*3, num_points)
......
......@@ -17,7 +17,7 @@ class ShapenetDataDgcnn(object):
small_data=False,
small_data_size=10,
return_cls_label=False,
num_classes=1, # None - all classes (50), 1 - one class, 2 - two classes, max 4
num_classes=1, # you can choose 1, 2, 3, 4, or 16
norm=False,
augmnetation=False,
data_augmentation=False
......@@ -241,12 +241,15 @@ class ShapenetDataDgcnn(object):
class_name = self.val_data_file[index].split('/')[-2]
# apply the mapper
# if self.num_classes:
# class_name = self.class_mapper_4_classes(class_name)
# else:
# class_name = self.class_mapper(class_name)
if self.num_classes in range(1, 5):
class_name = self.class_mapper_4_classes(class_name)
elif self.num_classes == 16:
class_name = self.class_mapper(class_name)
else:
raise ValueError('num_classes not in range, should be in range 1-4 or 16')
class_name = self.class_mapper(class_name)
# class_name = self.class_mapper(class_name)
# convert the class name to a number
class_name = np.array(class_name, dtype=np.int64)
......@@ -254,6 +257,7 @@ class ShapenetDataDgcnn(object):
# map to tensor
# class_name = torch.from_numpy(class_name)
if self.return_cls_label:
return point_set, labels, class_name
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment