Skip to content
Snippets Groups Projects
Commit d1edd85f authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated model dataloader for DGCNN

parent 024f1f32
Branches
No related tags found
No related merge requests found
......@@ -140,7 +140,8 @@ shapenet_data_train = ShapenetDataDgcnn(
small_data_size=config['data']['small_data_size'],
just_one_class=config['data']['just_one_class'],
split='train',
norm=config['data']['norm']
norm=config['data']['norm'],
augmnetation=True
)
# get val data
......
torch
wandb
pytorch-lightning==1.9.3
\ No newline at end of file
......@@ -19,6 +19,7 @@ class ShapenetDataDgcnn(object):
return_cls_label=False,
just_one_class=False,
norm=False,
augmnetation=False,
data_augmentation=False
) -> None:
......@@ -30,6 +31,7 @@ class ShapenetDataDgcnn(object):
self.return_cls_label = return_cls_label
self.just_one_class = just_one_class
self.norm = norm
self.augmnetation = augmnetation
self.data_augmentation = data_augmentation
# data operations
......@@ -143,6 +145,24 @@ class ShapenetDataDgcnn(object):
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
def jitter_pointcloud(self, pointcloud, sigma=0.01, clip=0.02):
N, C = pointcloud.shape
pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip)
return pointcloud
def rotate_pointcloud(self, pointcloud):
theta = np.pi*2 * np.random.uniform()
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) # random rotation (x,z)
return pointcloud
def translate_pointcloud(self, pointcloud):
xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
return translated_pointcloud
# TODO: add the selection for a given class
......@@ -164,6 +184,13 @@ class ShapenetDataDgcnn(object):
if self.norm:
point_set = self.normalize(point_set)
if self.augmnetation:
point_set = self.jitter_pointcloud(point_set)
point_set = self.rotate_pointcloud(point_set)
point_set = self.translate_pointcloud(point_set)
choice = np.random.choice(len(point_set), self.npoints, replace=True)
point_set = point_set[choice, :]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment