Skip to content
Snippets Groups Projects
Commit 584299f5 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

bug fix in the data loader

parent 0e069d09
Branches
No related tags found
No related merge requests found
......@@ -141,7 +141,7 @@ shapenet_data_train = ShapenetDataDgcnn(
just_one_class=config['data']['just_one_class'],
split='train',
norm=config['data']['norm'],
augmnetation=True
augmnetation=config['data']['augmentation']
)
# get val data
......@@ -152,7 +152,7 @@ shapenet_data_val = ShapenetDataDgcnn(
small_data=config['data']['small_data'],
small_data_size=config['data']['small_data_size'],
just_one_class=config['data']['just_one_class'],
split='train',
split='test',
norm=config['data']['norm']
)
......
......@@ -5,7 +5,48 @@ import torch.nn.init as init
# TODO: update wth https://github.com/antao97/dgcnn.pytorch/blob/07d534c2702905010ec9991619f552d8cacae45b/model.py#L166
# TODO: There are mode conv layers there
class EdgeConvNew(nn.Module):
def __init__(self, in_channels, out_channels):
super(EdgeConvNew, self).__init__()
self.in_channels = in_channels
self.conv = nn.Sequential(
nn.Conv2d(2*in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(negative_slope=0.2),
)
def forward(self, x, k=20):
batch_size = x.size(0)
num_points = x.size(2)
x = x.view(batch_size, -1, num_points)
idx = self.knn(x, k=k) # (batch_size, num_points, k)
idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
idx = idx + idx_base
idx = idx.view(-1)
_, num_dims, _ = x.size()
x = x.transpose(2, 1).contiguous()
feature = x.view(batch_size*num_points, -1)[idx, :]
feature = feature.view(batch_size, num_points, k, num_dims)
x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
feature = self.conv(feature) # (batch_size, num_dims, num_points, k)
return feature
def knn(self, x, k):
x = x.transpose(2, 1)
pairwise_distance = torch.cdist(x, x, p=2)
_, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False) # (batch_size, num_points, k)
return idx
class EdgeConv(nn.Module):
def __init__(self, in_channels, out_channels):
......@@ -91,8 +132,8 @@ class DgcnnClass(nn.Module):
def __init__(self, num_classes):
super(DgcnnClass, self).__init__()
self.transform_net = Transform_Net()
self.edge_conv1 = EdgeConv(3, 64)
self.edge_conv2 = EdgeConv(64, 128)
self.edge_conv1 = EdgeConvNew(3, 64)
self.edge_conv2 = EdgeConvNew(64, 128)
self.bn5 = nn.BatchNorm1d(256)
self.conv5 = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
self.bn5,
......@@ -110,18 +151,25 @@ class DgcnnClass(nn.Module):
def forward(self, x):
# Apply Transform_Net on input point cloud
batch_size = x.size(0)
trans_matrix = self.transform_net(x)
x = torch.bmm(x, trans_matrix)
batch_size = x.size(0)
num_points = x.size(1)
dim = x.size(2)
x = x.view(batch_size, dim, num_points)
x1 = self.edge_conv1(x)
x1 = x1.max(dim=-1, keepdim=False)[0]
# print("x1 shape: ", x1.shape)
x2 = self.edge_conv2(x1)
x2 = x2.max(dim=-1, keepdim=False)[0]
# print("x2 shape: ", x2.shape)
x5 = torch.cat((x1, x2), dim=2) # (batch_size, 64+64+128+256, num_points)
x5 = x5.transpose(2, 1) # (batch_size, num_points, 64+64+128+256)
x5 = torch.cat((x1, x2), dim=1) # (batch_size, 64+64+128+256, num_points)
# x5 = x5.transpose(2, 1) # (batch_size, num_points, 64+64+128+256)
# print("x5 shape: ", x5.shape)
x_conv = self.conv5(x5) # (batch_size, 1024, num_points)
# print("x_conv shape: ", x_conv.shape)
......
......@@ -83,21 +83,26 @@ class ShapenetDataDgcnn(object):
with open(json_file, 'r') as f:
data = json.load(f)
print('10 data in the list: ', data[:10])
out_data = []
for i in range(len(data)):
data[i] = os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))
out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt')))
# get one class of data
# get the the number of the class airplane
if self.just_one_class:
data = [x for x in data if x.split('/')[-2] in [
out_data = [x for x in out_data if x.split('/')[-2] in [
self.cat['Airplane'],
self.cat['Lamp'],
self.cat['Chair'],
self.cat['Table'],
]]
print('10 data in the out_data list: ', out_data[:10])
return data
return out_data
def get_seg_classes(self, cat):
return self.seg_classes[cat]
......@@ -206,7 +211,13 @@ class ShapenetDataDgcnn(object):
labels = labels.astype(np.int64)
# get the class name
class_name = self.train_file_list[index].split('/')[-2]
if self.split == 'train':
class_name = self.train_file_list[index].split('/')[-2]
elif self.split == 'test':
class_name = self.test_file_list[index].split('/')[-2]
elif self.split == 'val':
class_name = self.val_data_file[index].split('/')[-2]
# apply the mapper
class_name = self.class_mapper(class_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment