diff --git a/dgcnn/model.py b/dgcnn/model.py index 3d96217d76dd3d4b48b2a744d3788fb17c45947b..f90b1e2fd2e36b42f8852af028030b5736ec9504 100644 --- a/dgcnn/model.py +++ b/dgcnn/model.py @@ -95,7 +95,8 @@ class DGCNN(nn.Module): self.edge_conv3 = EdgeConv(128, 256) self.edge_conv4 = EdgeConv(256, 512) self.bn5 = nn.BatchNorm1d(1024) - self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), + self.bn6 = nn.BatchNorm1d(512) + self.conv5 = nn.Sequential(nn.Conv1d(960, 1024, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2)) self.linear1 = nn.Linear(2048, 512, bias=False) @@ -130,12 +131,19 @@ class DGCNN(nn.Module): x4 = self.edge_conv4(x3) x4 = x4.max(dim=-1, keepdim=False)[0] # print("x4 shape: ", x4.shape) - # x5 = torch.cat((x1, x2, x3, x4), dim=1) # (batch_size, 64+64+128+256, num_points) - # x6 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) - # x7 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) - # x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2) - - # x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512) + x5 = torch.cat((x1, x2, x3, x4), dim=2) # (batch_size, 64+64+128+256, num_points) + x5 = x5.transpose(2, 1) # (batch_size, num_points, 64+64+128+256) + # print("x5 shape: ", x5.shape) + x_conv = self.conv5(x5) # (batch_size, 1024, num_points) + # print("x_conv shape: ", x_conv.shape) + x6 = F.adaptive_max_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) + # print("x6 shape: ", x6.shape) + x7 = F.adaptive_avg_pool1d(x_conv, 1).view(batch_size, -1) # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims) + # print("x7 shape: ", x7.shape) + x8 = torch.cat((x6, x7), 1) # (batch_size, emb_dims*2) + + x8 = F.leaky_relu(self.bn6(self.linear1(x8)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512) + x9 = torch.max(x4, dim=1, keepdim=True)[0] x10 = self.fc(x9.squeeze(1)) diff --git a/requirements.txt b/requirements.txt index 81d6d253f102c38dd61cd16a23bd7fe7901fc4a5..206578ec3cb384c82d323add238634bf18d8bdae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,13 @@ numpy +torch +wandb tqdm hydra-core omegaconf -laspy \ No newline at end of file +laspy +pyg_lib +torch_scatter +torch_sparse +torch_cluster +torch_spline_conv +torch_geometric -f https://data.pyg.org/whl/torch-1.13.0+cu117.html \ No newline at end of file