Skip to content
Snippets Groups Projects
Commit 5a2f3f6a authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

point autoencoder started

parent 0441820e
Branches
No related tags found
No related merge requests found
import torch
import torch.nn.functional as F
def chamfer_distance(x, y):
    """
    Symmetric Chamfer distance between two batches of point sets.

    Distances are plain (non-squared) Euclidean distances from ``torch.cdist``.

    Args:
        x (torch.Tensor): Tensor of shape (B, N, D); B point sets of N points each.
        y (torch.Tensor): Tensor of shape (B, M, D); B point sets of M points each.

    Returns:
        torch.Tensor: Tensor of shape (B, 1). Each entry is the mean
        nearest-neighbour distance from x to y plus the mean
        nearest-neighbour distance from y to x for that batch element.
    """
    pairwise = torch.cdist(x, y)  # (B, N, M)

    # Nearest neighbour in y for every point of x, and the other way round.
    x_to_y = pairwise.min(dim=2).values  # (B, N)
    y_to_x = pairwise.min(dim=1).values  # (B, M)

    # keepdim=True keeps a trailing singleton axis, hence the (B, 1) result.
    return x_to_y.mean(dim=1, keepdim=True) + y_to_x.mean(dim=1, keepdim=True)
def chamfer_distance_simple(point_cloud1, point_cloud2):
    """
    Computes the Chamfer Distance between two point clouds using squared L2 distances.

    Args:
        point_cloud1 (torch.Tensor): First point cloud with shape (B, N, D), where B is
            the batch size, N the number of points and D the point dimension (e.g. 3).
        point_cloud2 (torch.Tensor): Second point cloud with shape (B, M, D).

    Returns:
        chamfer_dist (torch.Tensor): Chamfer Distance between the two point clouds of shape (B,).

    Raises:
        ValueError: If either input is not a 3-dimensional tensor.
    """
    if point_cloud1.dim() != 3 or point_cloud2.dim() != 3:
        raise ValueError(
            "point clouds must have shape (B, N, D) and (B, M, D); got "
            f"{tuple(point_cloud1.shape)} and {tuple(point_cloud2.shape)}"
        )

    # Broadcasting (B, N, 1, D) against (B, 1, M, D) yields every pairwise
    # difference without first materialising two repeated (B, N, M, D) copies.
    diff = point_cloud1.unsqueeze(dim=2) - point_cloud2.unsqueeze(dim=1)  # (B, N, M, D)

    # Squared L2 distances between all pairs of points in the two point clouds
    squared_distances = torch.sum(diff ** 2, dim=-1)  # (B, N, M)

    # Nearest-neighbour squared distance in each direction
    min_dist_1, _ = torch.min(squared_distances, dim=2)  # (B, N)
    min_dist_2, _ = torch.min(squared_distances, dim=1)  # (B, M)

    chamfer_dist = torch.mean(min_dist_1, dim=1) + torch.mean(min_dist_2, dim=1)  # (B,)
    return chamfer_dist
"""
model copied from Fxia22
"""
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class PCAutoEncoder(nn.Module):
    """Autoencoder for point clouds.

    The encoder is a stack of kernel-size-1 ``Conv1d`` layers followed by a max
    over the point axis; the decoder is three fully connected layers.

    Input:
        x: tensor of shape (batch_size, point_dim, num_points).
    Output:
        Tuple ``(reconstructed_points, global_feat)`` with shapes
        (batch_size, point_dim, num_points) and (batch_size, 1024).

    Args:
        point_dim (int): number of coordinates per point (input channels).
        num_points (int): number of points the decoder reconstructs.
    """

    def __init__(self, point_dim, num_points):
        super(PCAutoEncoder, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=point_dim, out_channels=64, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv4 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1)
        self.conv5 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1)

        self.fc1 = nn.Linear(in_features=1024, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=1024)
        # The decoder must emit point_dim values per point, because forward()
        # reshapes to (batch, point_dim, num_points). A hard-coded 3 only works
        # for 3-D points.
        self.fc3 = nn.Linear(in_features=1024, out_features=num_points * point_dim)

        # batch norm
        # NOTE: bn1 is shared by conv1..conv3 and bn3 by conv5, fc1 and fc2.
        # The sharing is kept as-is so existing checkpoints keep loading.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Encode ``x`` to a 1024-d global feature and decode it back to points."""
        batch_size = x.shape[0]
        point_dim = x.shape[1]
        num_points = x.shape[2]

        # encoder
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn1(self.conv2(x)))
        x = F.relu(self.bn1(self.conv3(x)))
        x = F.relu(self.bn2(self.conv4(x)))
        x = F.relu(self.bn3(self.conv5(x)))

        # max pooling over the point axis
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        # the global embedding
        global_feat = x

        # decoder
        x = F.relu(self.bn3(self.fc1(x)))
        x = F.relu(self.bn3(self.fc2(x)))
        reconstructed_points = self.fc3(x)

        # reshape the flat decoder output back to the input layout
        reconstructed_points = reconstructed_points.reshape(batch_size, point_dim, num_points)

        return reconstructed_points, global_feat
class PointNetAE(nn.Module):
    """Point-cloud autoencoder that chains ``PointEncoder`` and ``PointDecoder``.

    Args:
        num_points (int): number of points, forwarded to both sub-modules.
        k (int): accepted for interface compatibility; not used by this class.
    """

    def __init__(self, num_points = 2048, k = 2):
        super().__init__()
        self.num_points = num_points
        self.encoder = PointEncoder(num_points)
        self.decoder = PointDecoder(num_points)

    def forward(self, x):
        """Return ``(reconstruction, embedding)`` for the input batch ``x``."""
        encoded_embedding = self.encoder(x)
        reconstruction = self.decoder(encoded_embedding)
        return reconstruction, encoded_embedding
class STN3d(nn.Module):
    """Spatial Transformer Network for 3D point clouds."""

    def __init__(self, num_points=2500):
        """
        Args:
            num_points (int): number of input points (stored; not used by the layers)
        """
        super(STN3d, self).__init__()
        self.num_points = num_points
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input point cloud of shape (batch_size, 3, num_points)
        Returns:
            torch.Tensor: affine transformation matrix of shape (batch_size, 3, 3)
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))

        # Max over the point axis gives one 1024-d descriptor per cloud.
        x, _ = torch.max(x, 2)
        x = x.view(-1, 1024)

        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        # Bias the prediction towards the identity transform. The identity is
        # created directly on x's device and dtype, so this works on any
        # device (not only the current CUDA device), and broadcasting over the
        # batch axis replaces the explicit repeat.
        iden = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9)
        x = x + iden

        x = x.view(-1, 3, 3)
        return x
class PointNetfeat(nn.Module):
    """PointNet feature extractor with an optional input spatial transform.

    Args:
        num_points (int): nominal number of input points (stored, and passed to STN3d).
        global_feat (bool): when ``trans`` is True, selects between returning only
            the 1024-d global feature (True) or per-point features concatenated
            with the tiled global feature (False).
        trans (bool): whether to apply the STN3d input transform.

    Return value of ``forward`` (x of shape (batch, 3, n)):
        * ``trans=True,  global_feat=True``  -> ``(global (batch, 1024), trans (batch, 3, 3))``
        * ``trans=True,  global_feat=False`` -> ``(features (batch, 1088, n), trans)``
        * ``trans=False``                    -> ``global (batch, 1024)`` only
    """

    def __init__(self, num_points = 2500, global_feat = True, trans = True):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d(num_points = num_points)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.bn3 = torch.nn.BatchNorm1d(1024)
        self.trans = trans

        #self.mp1 = torch.nn.MaxPool1d(num_points)
        self.num_points = num_points
        self.global_feat = global_feat

    def forward(self, x):
        if self.trans:
            # Apply the predicted 3x3 transform to every point.
            trans = self.stn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)
            x = x.transpose(2, 1)

        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x  # per-point 64-d features, (batch, 64, n)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        # Max over the point axis -> global descriptor.
        x, _ = torch.max(x, 2)
        x = x.view(-1, 1024)

        if not self.trans:
            return x
        if self.global_feat:
            return x, trans

        # Tile the global feature to the actual number of input points rather
        # than the constructor's nominal num_points, so the concatenation also
        # works when the two differ.
        x = x.view(-1, 1024, 1).repeat(1, 1, pointfeat.size(2))
        return torch.cat([x, pointfeat], 1), trans
class PointEncoder(nn.Module):
    """Encode a point cloud of shape (batch, 3, n) into a 100-d embedding.

    Three kernel-size-1 convolutions (with batch norm) are followed by a max
    over the point axis and two fully connected layers.

    Args:
        num_points (int): stored on the instance; the layers do not depend on it.
    """

    def __init__(self, num_points = 2500):
        super().__init__()

        # Per-point feature extractor.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        # Projection of the global feature down to the embedding.
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 100)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.num_points = num_points

    def forward(self, x):
        """Return the (batch, 100) embedding of ``x``."""
        feat = F.relu(self.bn1(self.conv1(x)))
        feat = F.relu(self.bn2(self.conv2(feat)))
        feat = self.bn3(self.conv3(feat))  # no ReLU before the pooling

        # Max over the point axis, then flatten to (batch, 1024).
        pooled = torch.max(feat, 2)[0].view(-1, 1024)

        hidden = F.relu(self.fc1(pooled))
        return self.fc2(hidden)
class PointDecoder(nn.Module):
    """Decode a 100-d embedding into a point cloud of shape (batch, 3, num_points).

    Args:
        num_points (int): number of points to generate.
        k (int): accepted for interface compatibility; not used by this class.
    """

    def __init__(self, num_points = 2048, k = 2):
        super().__init__()
        self.num_points = num_points
        self.fc1 = nn.Linear(100, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 1024)
        self.fc5 = nn.Linear(1024, self.num_points * 3)
        self.th = nn.Tanh()

    def forward(self, x):
        """Return tanh-bounded coordinates reshaped to (batch, 3, num_points)."""
        n_batch = x.size(0)

        # Four ReLU-activated hidden layers of growing width.
        for hidden_layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(hidden_layer(x))

        # Tanh keeps every output coordinate within [-1, 1].
        coords = self.th(self.fc5(x))
        return coords.view(n_batch, 3, self.num_points)
# Smoke check when the model module is executed directly; it only prints a
# greeting and does not build any model.
if __name__ == "__main__":
    print("hello model")
\ No newline at end of file
import json
import os
import numpy as np
class ShapenetData(object):
    """
    Data loader for the ShapeNet dataset. Only for data segmentation, not for classification.

    Indexing returns a ``float32`` array of shape (npoints, 3) holding the first
    three columns of the selected sample file.

    Args:
        root (str): dataset root; files are looked up under ``<root>/raw``.
        npoints (int): number of points returned per sample.
        split (str): one of ``'train'``, ``'test'`` or ``'val'``.
        small_data (bool): if True, keep only the first ``small_data_size`` files of each split.
        small_data_size (int): number of files kept when ``small_data`` is True.
        just_one_class (bool): if True, keep only files of the 'Airplane' category.
        norm (bool): if True, centre each cloud and scale it to the unit sphere.
        data_augmentation (bool): stored on the instance; not used by this class.
    """

    def __init__(self,
                 root,
                 npoints=1200,
                 split='train',
                 small_data=False,
                 small_data_size=10,
                 just_one_class=False,
                 norm=True,
                 data_augmentation=False
                 ) -> None:
        self.npoints = npoints
        self.root = root
        self.split = split
        self.small_data = small_data
        self.small_data_size = small_data_size
        self.just_one_class = just_one_class
        self.norm = norm
        self.data_augmentation = data_augmentation

        # Locations of the split lists and of the category table.
        split_dir = os.path.join(self.root, 'raw', 'train_test_split')
        self.train_data_suffled_list_file = os.path.join(split_dir, 'shuffled_train_file_list.json')
        self.test_data_suffled_list_file = os.path.join(split_dir, 'shuffled_test_file_list.json')
        self.val_data_file = os.path.join(split_dir, 'shuffled_val_file_list.json')
        self.catfile = os.path.join(self.root, 'raw', 'synsetoffset2category.txt')

        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
                            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
                            'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
                            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
                            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}

        # Map first column -> second column of the category file.
        self.cat = {}
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                if len(ls) < 2:
                    # Tolerate blank or malformed lines instead of raising IndexError.
                    continue
                self.cat[ls[0]] = ls[1]

        self.train_file_list = self.get_list_shuffled_data(self.root, self.train_data_suffled_list_file)
        self.test_file_list = self.get_list_shuffled_data(self.root, self.test_data_suffled_list_file)
        # NOTE: val_data_file is rebound from the JSON path to the resolved file
        # list; the attribute name is kept for backward compatibility.
        self.val_data_file = self.get_list_shuffled_data(self.root, self.val_data_file)

        # Optionally keep only the first small_data_size files of every split.
        if self.small_data:
            self.train_file_list = self.train_file_list[:self.small_data_size]
            self.test_file_list = self.test_file_list[:self.small_data_size]
            self.val_data_file = self.val_data_file[:self.small_data_size]

    def get_list_shuffled_data(self, root, json_file):
        """Read a split JSON file and return the sample file paths it lists.

        Each JSON entry is reduced to its last two '/'-separated components
        and mapped to ``<root>/raw/<component -2>/<component -1>.txt``. When
        ``just_one_class`` is set, only paths whose parent directory equals
        ``self.cat['Airplane']`` are kept.
        """
        with open(json_file, 'r') as f:
            entries = json.load(f)

        paths = []
        for entry in entries:
            parts = entry.split('/')
            paths.append(os.path.join(root, 'raw', parts[-2], parts[-1] + '.txt'))

        if self.just_one_class:
            airplane_dir = self.cat['Airplane']
            # Compare the parent directory name through os.path so the filter
            # does not depend on '/' being the platform's path separator.
            paths = [p for p in paths if os.path.basename(os.path.dirname(p)) == airplane_dir]
        return paths

    def get_seg_classes(self, cat):
        """Return the list of part labels registered for category name ``cat``."""
        return self.seg_classes[cat]

    def get_class_names(self):
        """Return the values of the category table (second column of the category file)."""
        return list(self.cat.values())

    def get_all_names_of_classes(self):
        """Return the keys of the category table (first column of the category file)."""
        return list(self.cat.keys())

    def normalize(self, pc):
        """Centre ``pc`` (n, d) on its centroid and scale it to the unit sphere.

        A degenerate cloud whose points all coincide is only centred, which
        avoids a division by zero that would fill the result with NaNs.
        """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
        if m > 0:
            pc = pc / m
        return pc

    def _active_file_list(self):
        """Return the file list of the configured split; raise ValueError if unknown."""
        if self.split == 'train':
            return self.train_file_list
        if self.split == 'test':
            return self.test_file_list
        if self.split == 'val':
            return self.val_data_file
        raise ValueError(f"Unknown split {self.split!r}; expected 'train', 'test' or 'val'")

    # TODO: add the selection for a given class
    def __getitem__(self, index):
        """Load sample ``index`` of the active split as a (npoints, 3) float32 array.

        Raises:
            ValueError: if the split is unknown or the sample file has no points.
        """
        path = self._active_file_list()[index]
        # ndmin=2 keeps a single-row file two-dimensional so the column slice works.
        point_set = np.loadtxt(path, ndmin=2).astype(np.float32)
        if point_set.shape[0] == 0:
            raise ValueError(f"No points found in {path!r}")

        # get just the points
        point_set = point_set[:, 0:3]

        # normalize the points
        if self.norm:
            point_set = self.normalize(point_set)

        # Take the first npoints rows. Clouds with fewer rows are cycled through
        # again, so every sample has the same shape instead of raising IndexError.
        choice = np.arange(self.npoints) % point_set.shape[0]
        point_set = point_set[choice, :]
        point_set = point_set.astype(np.float32)

        return point_set

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self._active_file_list())
# Manual smoke test: builds the training split from a hard-coded local dataset
# path, prints the category metadata and the first ten points of sample 0.
if __name__ == "__main__":
    shapenet_data = ShapenetData(root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet', split='train')
    # print(shapenet_data.train_file_list)
    print(shapenet_data.get_seg_classes('Car'))
    print(shapenet_data.get_class_names())
    print(shapenet_data.get_all_names_of_classes())
    # get the first point cloud
    point_cloud = shapenet_data[0]
    print(point_cloud[:10])
\ No newline at end of file
import os
from shapenet_data import ShapenetData
from torch.utils.data import DataLoader
import torch.nn.init as init
import torch.nn as nn
import torch
import wandb
# from chamfer_distance import ChamferDistance
# add wandb
wandb.init(project="forest-point-autoencoder", entity="maciej-wielgosz-nibio")
from helpers.chamfer_distance import chamfer_distance, chamfer_distance_simple
from models.model import PointNetAE, PCAutoEncoder
# Number of points per cloud; shared by the dataset and the model below.
n_points = 400
# Training split restricted to its first 300 files (small_data=True), read from
# a hard-coded local dataset root.
train_ds = ShapenetData(
root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
split='train',
npoints=n_points,
small_data=True,
small_data_size=300,
just_one_class=False,
norm=True
)
# create a dataloader (fixed order: shuffle=False)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=False, num_workers=4)
# create a model
model = PointNetAE(num_points=n_points, k=2)
# model = PCAutoEncoder(point_dim=3, num_points=8)
model = model.train()
# Setting up Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# Halve the learning rate after every 2000 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.5)
# initialize weights
def init_weights_xavier(m):
    """Xavier-uniform initialise the weights of Conv1d/Linear modules.

    Meant to be passed to ``Module.apply``: modules of any other type are left
    untouched, and a bias is zeroed only when the layer has one.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return
    init.xavier_uniform_(m.weight)
    if m.bias is not None:
        init.zeros_(m.bias)
# Re-initialise every Conv1d/Linear layer of the model in place.
model.apply(init_weights_xavier)
# create folder for trained models to be saved
path_for_save = '/home/nibio/mutable-outside-world/code/oracle_gpu_runs/point_net_autoencoder'
path_for_save = os.path.join(path_for_save, 'saved_models')
os.makedirs(path_for_save, exist_ok=True)
# Pick the device. `device` must always be bound because the training loop
# calls data.to(device); without the CPU fallback a machine without CUDA hits
# a NameError there.
if torch.cuda.is_available():
    print("Using the GPU")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# put the model on the selected device (a no-op when it is the CPU)
model.to(device)
# chamfer_dist_otaheri = ChamferDistance()
# Start the Training
# Each batch from the loader is moved to `device` and transposed from
# (batch, n_points, 3) to (batch, 3, n_points) before it is fed to the model;
# the loss is the MSE between the reconstruction and that transposed input.
# NOTE(review): the nesting of the two epoch-level statements at the bottom
# (epoch logging and scheduler.step()) is assumed to be once per epoch,
# following the "log epochs" comment — confirm against the original file.
for epoch in range(4000):
    for i, data in enumerate(train_dl):
        # print("data: ", data)
        points = data.to(device)
        points = points.transpose(2, 1)
        # points = points.cuda()
        optimizer.zero_grad()
        reconstructed_points, global_feat = model(points)
        # train_loss= chamfer_distance_simple(points, reconstructed_points).mean()
        train_loss = torch.nn.functional.mse_loss(reconstructed_points, points)
        # chamfer_dist_otaheri(points, reconstructed_points)
        # dist1, dist2, idx1, idx2 = chamfer_dist_otaheri(points, reconstructed_points)
        # train_loss = (torch.mean(dist1)) + (torch.mean(dist2))
        wandb.log({"train_loss": train_loss})
        print(f"Epoch: {epoch}, Iteration#: {i}, Train Loss: {train_loss}")
        print("current learning rate: ", optimizer.param_groups[0]['lr'])
        # Calculate the gradients using backpropagation
        train_loss.backward()
        # Update the weights and biases
        optimizer.step()
    # log epochs
    wandb.log({"epoch": epoch})
    scheduler.step()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment