class EdgeConvNew(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph built in feature space.

    For every point the layer looks up its ``k`` nearest neighbours (Euclidean
    distance over the channel dimension), forms the edge feature
    ``[neighbour - centre, centre]`` and passes it through a shared
    1x1 ``Conv2d -> BatchNorm2d -> LeakyReLU(0.2)`` stack.

    Args:
        in_channels: Number of channels of the per-point input features.
        out_channels: Number of channels produced for every edge.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        # The edge feature concatenates (neighbour - centre) with centre,
        # hence the doubled channel count on the convolution input.
        self.conv = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, x, k=20):
        """Compute edge features for every point and its ``k`` neighbours.

        Args:
            x: Tensor whose first dimension is the batch and whose third
                dimension is the number of points; all remaining dimensions
                are folded into the channel axis, giving
                ``(batch_size, num_dims, num_points)``.
            k: Number of nearest neighbours per point (the point itself is
                its own nearest neighbour). ``torch.topk`` raises if ``k``
                exceeds ``num_points``.

        Returns:
            Tensor of shape ``(batch_size, out_channels, num_points, k)``.
        """
        batch_size = x.size(0)
        num_points = x.size(2)
        # reshape instead of view: also accepts non-contiguous inputs.
        x = x.reshape(batch_size, -1, num_points)
        idx = self.knn(x, k=k)  # (batch_size, num_points, k)

        points = x.transpose(2, 1)  # (batch_size, num_points, num_dims)
        # Batched advanced indexing gathers the neighbours of every point
        # without flattening the batch and offsetting the indices by hand.
        batch_idx = torch.arange(batch_size, device=x.device).view(-1, 1, 1)
        neighbours = points[batch_idx, idx]  # (batch_size, num_points, k, num_dims)

        # expand returns a broadcast view, so the centre point is not
        # physically copied k times before the concatenation.
        centre = points.unsqueeze(2).expand_as(neighbours)

        feature = torch.cat((neighbours - centre, centre), dim=3)
        feature = feature.permute(0, 3, 1, 2).contiguous()  # (batch_size, 2*num_dims, num_points, k)
        return self.conv(feature)  # (batch_size, out_channels, num_points, k)

    @staticmethod
    def knn(x, k):
        """Return the indices of the ``k`` nearest neighbours of every point.

        Args:
            x: Tensor of shape ``(batch_size, num_dims, num_points)``.
            k: Number of neighbours to select per point.

        Returns:
            Integer tensor of shape ``(batch_size, num_points, k)`` ordered
            from the closest neighbour to the farthest.
        """
        points = x.transpose(2, 1)  # (batch_size, num_points, num_dims)
        pairwise_distance = torch.cdist(points, points, p=2)
        _, idx = torch.topk(pairwise_distance, k=k, dim=-1, largest=False)
        return idx
d6f6b029c03281d7726ab15d6906864306ede860..32ed47a1cdf706c81b66b7ac41d786511fb96bb2 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -83,21 +83,26 @@ class ShapenetDataDgcnn(object): with open(json_file, 'r') as f: data = json.load(f) + print('10 data in the list: ', data[:10]) + + out_data = [] for i in range(len(data)): - data[i] = os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt')) + out_data.append(os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))) # get one class of data # get the the number of the class airplane if self.just_one_class: - data = [x for x in data if x.split('/')[-2] in [ + out_data = [x for x in out_data if x.split('/')[-2] in [ self.cat['Airplane'], self.cat['Lamp'], self.cat['Chair'], self.cat['Table'], ]] + + print('10 data in the out_data list: ', out_data[:10]) - return data + return out_data def get_seg_classes(self, cat): return self.seg_classes[cat] @@ -206,7 +211,13 @@ class ShapenetDataDgcnn(object): labels = labels.astype(np.int64) # get the class name - class_name = self.train_file_list[index].split('/')[-2] + if self.split == 'train': + class_name = self.train_file_list[index].split('/')[-2] + elif self.split == 'test': + class_name = self.test_file_list[index].split('/')[-2] + elif self.split == 'val': + class_name = self.val_data_file[index].split('/')[-2] + # apply the mapper class_name = self.class_mapper(class_name)