From 5a2f3f6a2ef9215c9be4ee7c08935def86a87e4d Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Fri, 17 Mar 2023 15:35:13 +0100
Subject: [PATCH] point autoencoder started

---
 point_net_autoencoder/helpers/__init__.py     |   0
 .../helpers/chamfer_distance.py               |  62 +++++
 point_net_autoencoder/models/__init__.py      |   0
 point_net_autoencoder/models/model.py         | 221 ++++++++++++++++++
 point_net_autoencoder/shapenet_data.py        | 154 ++++++++++++
 point_net_autoencoder/train_shapenet.py       | 107 +++++++++
 6 files changed, 544 insertions(+)
 create mode 100644 point_net_autoencoder/helpers/__init__.py
 create mode 100644 point_net_autoencoder/helpers/chamfer_distance.py
 create mode 100644 point_net_autoencoder/models/__init__.py
 create mode 100644 point_net_autoencoder/models/model.py
 create mode 100644 point_net_autoencoder/shapenet_data.py
 create mode 100644 point_net_autoencoder/train_shapenet.py

diff --git a/point_net_autoencoder/helpers/__init__.py b/point_net_autoencoder/helpers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/point_net_autoencoder/helpers/chamfer_distance.py b/point_net_autoencoder/helpers/chamfer_distance.py
new file mode 100644
index 0000000..e802a70
--- /dev/null
+++ b/point_net_autoencoder/helpers/chamfer_distance.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn.functional as F
+
+def chamfer_distance(x, y):
+    """
+    Computes the Chamfer distance between two sets of points x and y.
+
+    Args:
+        x (torch.Tensor): Tensor of shape (B, N, D) representing a batch of B point sets, each containing N points in D dimensions.
+        y (torch.Tensor): Tensor of shape (B, M, D) representing a batch of B point sets, each containing M points in D dimensions.
+
+    Returns:
+        Tuple of two floats: the average distance from each point in x to its nearest neighbor in y, and the average distance from each point in y to its nearest neighbor in x.
+    """
+    # Compute pairwise distances between all points in x and y
+    dists = torch.cdist(x, y)  # shape: (B, N, M)
+
+    # Compute the distance from each point in x to its nearest neighbor in y
+    min_dists_x, _ = torch.min(dists, dim=2)  # shape: (B, N)
+    avg_dist_x = torch.mean(min_dists_x, dim=1, keepdim=True)  # shape: ()
+
+    # Compute the distance from each point in y to its nearest neighbor in x
+    min_dists_y, _ = torch.min(dists, dim=1)  # shape: (B, M)
+    avg_dist_y = torch.mean(min_dists_y, dim=1, keepdim=True)  # shape: ()
+
+    return avg_dist_x + avg_dist_y
+
+
+
+def chamfer_distance_simple(point_cloud1, point_cloud2):
+    """
+    Computes the Chamfer Distance between two point clouds.
+    
+    Args:
+        point_cloud1 (torch.Tensor): First point cloud with shape (B, N, 3), where B is batch size, N is the number of points.
+        point_cloud2 (torch.Tensor): Second point cloud with shape (B, M, 3), where B is batch size, M is the number of points.
+
+    Returns:
+        chamfer_dist (torch.Tensor): Chamfer Distance between the two point clouds of shape (B,).
+    """
+    B, N, _ = point_cloud1.size()
+    _, M, _ = point_cloud2.size()
+    
+    point_cloud1 = point_cloud1.unsqueeze(dim=2)  # Shape: (B, N, 1, 3)
+    point_cloud2 = point_cloud2.unsqueeze(dim=1)  # Shape: (B, 1, M, 3)
+    
+    point_cloud1 = point_cloud1.repeat(1, 1, M, 1)  # Shape: (B, N, M, 3)
+    point_cloud2 = point_cloud2.repeat(1, N, 1, 1)  # Shape: (B, N, M, 3)
+    
+    # Compute squared L2 distances between all pairs of points in the two point clouds
+    squared_distances = torch.sum((point_cloud1 - point_cloud2) ** 2, dim=-1)  # Shape: (B, N, M)
+    
+    # Compute the Chamfer Distance
+    min_dist_1, _ = torch.min(squared_distances, dim=2)  # Shape: (B, N)
+    min_dist_2, _ = torch.min(squared_distances, dim=1)  # Shape: (B, M)
+    
+    chamfer_dist = torch.mean(min_dist_1, dim=1) + torch.mean(min_dist_2, dim=1)  # Shape: (B,)
+
+    return chamfer_dist
+
+
+
diff --git a/point_net_autoencoder/models/__init__.py b/point_net_autoencoder/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/point_net_autoencoder/models/model.py b/point_net_autoencoder/models/model.py
new file mode 100644
index 0000000..ff68cb1
--- /dev/null
+++ b/point_net_autoencoder/models/model.py
@@ -0,0 +1,221 @@
+"""
+model copied from Fxia22
+
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+
+class PCAutoEncoder(nn.Module):
+    """ Autoencoder for Point Cloud 
+    Input: 
+    Output: 
+    """
+    def __init__(self, point_dim, num_points):
+        super(PCAutoEncoder, self).__init__()
+
+        self.conv1 = nn.Conv1d(in_channels=point_dim, out_channels=64, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
+        self.conv3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
+        self.conv4 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1)
+        self.conv5 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1)
+
+        self.fc1 = nn.Linear(in_features=1024, out_features=1024)
+        self.fc2 = nn.Linear(in_features=1024, out_features=1024)
+        self.fc3 = nn.Linear(in_features=1024, out_features=num_points*3)
+
+        #batch norm
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(1024)
+    
+    def forward(self, x):
+
+        batch_size = x.shape[0]
+        point_dim = x.shape[1]
+        num_points = x.shape[2]
+
+        #encoder
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn1(self.conv2(x)))
+        x = F.relu(self.bn1(self.conv3(x)))
+        x = F.relu(self.bn2(self.conv4(x)))
+        x = F.relu(self.bn3(self.conv5(x)))
+
+        # do max pooling 
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+        # get the global embedding
+        global_feat = x
+
+        #decoder
+        x = F.relu(self.bn3(self.fc1(x)))
+        x = F.relu(self.bn3(self.fc2(x)))
+        reconstructed_points = self.fc3(x)
+
+        #do reshaping
+        reconstructed_points = reconstructed_points.reshape(batch_size, point_dim, num_points)
+
+        return reconstructed_points , global_feat
+
+class PointNetAE(nn.Module):
+    def __init__(self, num_points = 2048, k = 2):
+        super(PointNetAE, self).__init__()
+        self.num_points = num_points
+        # self.encoder = nn.Sequential(
+        # PointNetfeat(num_points, global_feat=True, trans = False),
+        # nn.Linear(1024, 256),
+        # nn.ReLU(),
+        # nn.Linear(256, 100),
+        # )
+
+        self.encoder = PointEncoder(num_points)
+        self.decoder = PointDecoder(num_points)
+
+    def forward(self, x):
+
+        x = self.encoder(x)
+
+        encoded_embedding = x 
+        
+        x = self.decoder(x)
+
+        return x, encoded_embedding
+
+class STN3d(nn.Module):
+    """Spatial Transformer Network for 3D point clouds."""
+
+    def __init__(self, num_points=2500):
+        """
+        Args:
+            num_points (int): number of input points
+        """
+        super(STN3d, self).__init__()
+        self.num_points = num_points
+        self.conv1 = nn.Conv1d(3, 64, 1)
+        self.conv2 = nn.Conv1d(64, 128, 1)
+        self.conv3 = nn.Conv1d(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, 9)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): input point cloud of shape (batch_size, 3, num_points)
+
+        Returns:
+            torch.Tensor: affine transformation matrix of shape (batch_size, 3, 3)
+        """
+        batchsize = x.size()[0]
+        x = self.relu(self.conv1(x))
+        x = self.relu(self.conv2(x))
+        x = self.relu(self.conv3(x))
+        x, _ = torch.max(x, 2)
+        x = x.view(-1, 1024)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+
+        iden = torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=np.float32)).view(1, 9).repeat(batchsize, 1)
+        if x.is_cuda:
+            iden = iden.cuda()
+        x = x + iden
+        x = x.view(-1, 3, 3)
+        return x
+
+class PointNetfeat(nn.Module):
+    def __init__(self, num_points = 2500, global_feat = True, trans = True):
+        super(PointNetfeat, self).__init__()
+        self.stn = STN3d(num_points = num_points)
+        self.conv1 = torch.nn.Conv1d(3, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+
+        self.bn1 = torch.nn.BatchNorm1d(64)
+        self.bn2 = torch.nn.BatchNorm1d(128)
+        self.bn3 = torch.nn.BatchNorm1d(1024)
+        self.trans = trans
+
+
+        #self.mp1 = torch.nn.MaxPool1d(num_points)
+        self.num_points = num_points
+        self.global_feat = global_feat
+    def forward(self, x):
+        batchsize = x.size()[0]
+        if self.trans:
+            trans = self.stn(x)
+            x = x.transpose(2,1)
+            x = torch.bmm(x, trans)
+            x = x.transpose(2,1)
+        x = F.relu(self.bn1(self.conv1(x)))
+        pointfeat = x
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = self.bn3(self.conv3(x))
+        x,_ = torch.max(x, 2)
+        x = x.view(-1, 1024)
+        if self.trans:
+            if self.global_feat:
+                return x, trans
+            else:
+                x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
+                return torch.cat([x, pointfeat], 1), trans
+        else:
+            return x
+
+class PointEncoder(nn.Module):
+    def __init__(self, num_points = 2500):
+        super(PointEncoder, self).__init__()
+        self.conv1 = torch.nn.Conv1d(3, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+
+        self.fc1 = torch.nn.Linear(1024, 256)
+        self.fc2 = torch.nn.Linear(256, 100)
+
+        self.bn1 = torch.nn.BatchNorm1d(64)
+        self.bn2 = torch.nn.BatchNorm1d(128)
+        self.bn3 = torch.nn.BatchNorm1d(1024)
+
+        #self.mp1 = torch.nn.MaxPool1d(num_points)
+        self.num_points = num_points
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = self.bn3(self.conv3(x))
+        x,_ = torch.max(x, 2)
+        x = x.view(-1, 1024)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+
+        return x
+
+class PointDecoder(nn.Module):
+    def __init__(self, num_points = 2048, k = 2):
+        super(PointDecoder, self).__init__()
+        self.num_points = num_points
+        self.fc1 = nn.Linear(100, 128)
+        self.fc2 = nn.Linear(128, 256)
+        self.fc3 = nn.Linear(256, 512)
+        self.fc4 = nn.Linear(512, 1024)
+        self.fc5 = nn.Linear(1024, self.num_points * 3)
+        self.th = nn.Tanh()
+    def forward(self, x):
+        batchsize = x.size()[0]
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = F.relu(self.fc3(x))
+        x = F.relu(self.fc4(x))
+        x = self.th(self.fc5(x))
+        x = x.view(batchsize, 3, self.num_points)
+        return x
+    
+
+
+if __name__ == "__main__":
+    print("hello model")
\ No newline at end of file
diff --git a/point_net_autoencoder/shapenet_data.py b/point_net_autoencoder/shapenet_data.py
new file mode 100644
index 0000000..7cd999b
--- /dev/null
+++ b/point_net_autoencoder/shapenet_data.py
@@ -0,0 +1,154 @@
+
+import json
+import os
+
+import numpy as np
+
+
+class ShapenetData(object):
+    """
+    The is the data loader for the ShapeNet dataset. Only for data segmentation, not for classification.
+    """
+    def __init__(self,
+                 root,
+                 npoints=1200,
+                 split='train',
+                 small_data=False,
+                 small_data_size=10,
+                 just_one_class=False,
+                 norm=True,
+                 data_augmentation=False
+                 ) -> None:
+        
+        self.npoints = npoints
+        self.root = root
+        self.split = split
+        self.small_data = small_data
+        self.small_data_size = small_data_size
+        self.just_one_class = just_one_class
+        self.norm = norm
+        self.data_augmentation = data_augmentation
+
+        # data operations
+        self.train_data_suffled_list_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_train_file_list.json')
+        self.test_data_suffled_list_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_test_file_list.json')
+        self.val_data_file = os.path.join(
+            self.root, 
+            'raw',
+            'train_test_split',
+            'shuffled_val_file_list.json')
+        
+
+        self.catfile = os.path.join(self.root, 
+                                    'raw',
+                                    'synsetoffset2category.txt')
+        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
+                        'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
+                        'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
+                        'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
+                        'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
+        
+        self.cat = {}
+        with open(self.catfile, 'r') as f:
+            for line in f:
+                ls = line.strip().split()
+                self.cat[ls[0]] = ls[1]
+
+        self.train_file_list = self.get_list_shuffled_data(self.root, self.train_data_suffled_list_file)
+        self.test_file_list = self.get_list_shuffled_data(self.root, self.test_data_suffled_list_file)
+        self.val_data_file = self.get_list_shuffled_data(self.root, self.val_data_file)
+
+        # take the first 10 data for training
+        if self.small_data:
+            self.train_file_list = self.train_file_list[:self.small_data_size]
+            self.test_file_list = self.test_file_list[:self.small_data_size]
+            self.val_data_file = self.val_data_file[:self.small_data_size]
+        
+
+    def get_list_shuffled_data(self,root, json_file):
+       # read the json file and return the list of data
+        with open(json_file, 'r') as f:
+            data = json.load(f)
+
+        for i in range(len(data)):
+            data[i] = os.path.join(root, 'raw', os.path.join(data[i].split('/')[-2], data[i].split('/')[-1] + '.txt'))
+
+        # get one class of data
+        # get the the number of the class airplane
+
+        if self.just_one_class:
+            data = [x for x in data if x.split('/')[-2] == self.cat['Airplane']]
+
+        return data
+    
+    def get_seg_classes(self, cat):
+        return self.seg_classes[cat]
+    
+    def get_class_names(self):
+        return list(self.cat.values())
+    
+    def get_all_names_of_classes(self):
+        return list(self.cat.keys())
+    
+    def normalize(self, pc):
+        centroid = np.mean(pc, axis=0)
+        pc = pc - centroid
+        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+        pc = pc / m
+        return pc
+        
+    # TODO: add the selection for a given class
+
+    def __getitem__(self, index):
+        if self.split == 'train':
+            point_set = np.loadtxt(self.train_file_list[index]).astype(np.float32)
+        elif self.split == 'test':
+            point_set = np.loadtxt(self.test_file_list[index]).astype(np.float32)
+        elif self.split == 'val':
+            point_set = np.loadtxt(self.val_data_file[index]).astype(np.float32)
+
+        # get just the points
+        point_set = point_set[:, 0:3]
+
+        # normalize the points
+        if self.norm:
+            point_set = self.normalize(point_set)
+
+        # choice = np.random.choice(len(point_set), self.npoints, replace=True)
+        # chose the first npoints
+        choice = np.arange(self.npoints)
+
+        point_set = point_set[choice, :]
+        point_set = point_set.astype(np.float32)
+
+        return point_set
+
+    def __len__(self):
+        if self.split == 'train':
+            return len(self.train_file_list)
+        elif self.split == 'test':
+            return len(self.test_file_list)
+        elif self.split == 'val':
+            return len(self.val_data_file)
+
+
+if __name__ == "__main__":
+
+  shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
+#   print(shapenet_data.train_file_list)
+  print(shapenet_data.get_seg_classes('Car'))
+  print(shapenet_data.get_class_names())
+  print(shapenet_data.get_all_names_of_classes())
+
+  # get the first point cloud
+  point_cloud = shapenet_data[0]
+
+  print(point_cloud[:10])
\ No newline at end of file
diff --git a/point_net_autoencoder/train_shapenet.py b/point_net_autoencoder/train_shapenet.py
new file mode 100644
index 0000000..bee49fb
--- /dev/null
+++ b/point_net_autoencoder/train_shapenet.py
@@ -0,0 +1,107 @@
+import os
+from shapenet_data import ShapenetData
+from torch.utils.data import DataLoader
+
+import torch.nn.init as init
+import torch.nn as nn
+
+import torch
+import wandb
+
+# from chamfer_distance import ChamferDistance
+
+# add wandb
+wandb.init(project="forest-point-autoencoder", entity="maciej-wielgosz-nibio")
+
+
+
+from helpers.chamfer_distance import chamfer_distance, chamfer_distance_simple
+
+from models.model import PointNetAE, PCAutoEncoder
+
+n_points = 400
+
+train_ds = ShapenetData(
+    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', 
+    split='train', 
+    npoints=n_points, 
+    small_data=True,
+    small_data_size=300,
+    just_one_class=False,
+    norm=True
+    )
+
+# create a dataloader
+train_dl = DataLoader(train_ds, batch_size=16, shuffle=False, num_workers=4)
+
+# create a model
+model = PointNetAE(num_points=n_points, k=2)
+
+# model = PCAutoEncoder(point_dim=3, num_points=8)
+
+model = model.train()
+
+# Setting up Optimizer
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
+scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.5)
+
+# intialize weights
+
+def init_weights_xavier(m):
+    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
+        init.xavier_uniform_(m.weight)
+        if m.bias is not None:
+            init.zeros_(m.bias)
+
+model.apply(init_weights_xavier)
+
+
+# create folder for trained models to be saved
+path_for_save = '/home/nibio/mutable-outside-world/code/oracle_gpu_runs/point_net_autoencoder'
+path_for_save = os.path.join(path_for_save, 'saved_models')
+os.makedirs(path_for_save, exist_ok=True)
+
+# device = torch.device("cpu")
+
+if torch.cuda.is_available():
+    print("Using the GPU")
+    device = torch.device("cuda")
+
+# put the model on the GPU
+model.to(device)
+
+# chamfer_dist_otaheri = ChamferDistance()
+
+
+# Start the Training 
+for epoch in range(4000):
+    for i, data in enumerate(train_dl):
+        # print("data: ", data)
+        points = data.to(device)
+        points = points.transpose(2, 1)
+        # points = points.cuda()
+        optimizer.zero_grad()
+        reconstructed_points, global_feat = model(points)
+
+        # train_loss= chamfer_distance_simple(points, reconstructed_points).mean()
+        train_loss = torch.nn.functional.mse_loss(reconstructed_points, points)
+        # chamfer_dist_otaheri(points, reconstructed_points)
+        # dist1, dist2, idx1, idx2 = chamfer_dist_otaheri(points, reconstructed_points)
+        # train_loss = (torch.mean(dist1)) + (torch.mean(dist2))
+
+        wandb.log({"train_loss": train_loss})
+
+        print(f"Epoch: {epoch}, Iteration#: {i}, Train Loss: {train_loss}")
+        print("current learning rate: ", optimizer.param_groups[0]['lr'])
+
+        # Calculate the gradients using Back Propogation
+        train_loss.backward() 
+
+        # Update the weights and biases 
+        optimizer.step()
+
+    # log epochs
+    wandb.log({"epoch": epoch})
+
+        
+    scheduler.step()
\ No newline at end of file
-- 
GitLab