From d1edd85f6850406f23bb123144c49a7e2e1413df Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Wed, 29 Mar 2023 12:34:39 +0200 Subject: [PATCH] updated model dataloader for DGCNN --- dgcnn/dgcnn_train_pl.py | 3 ++- dgcnn/req.txt | 3 +++ dgcnn/shapenet_data_dgcnn.py | 27 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 dgcnn/req.txt diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index 67b04b9..65b4fb5 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -140,7 +140,8 @@ shapenet_data_train = ShapenetDataDgcnn( small_data_size=config['data']['small_data_size'], just_one_class=config['data']['just_one_class'], split='train', - norm=config['data']['norm'] + norm=config['data']['norm'], + augmnetation=True ) # get val data diff --git a/dgcnn/req.txt b/dgcnn/req.txt new file mode 100644 index 0000000..efd1de6 --- /dev/null +++ b/dgcnn/req.txt @@ -0,0 +1,3 @@ +torch +wandb +pytorch-lightning==1.9.3 \ No newline at end of file diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 501caff..8196648 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -19,6 +19,7 @@ class ShapenetDataDgcnn(object): return_cls_label=False, just_one_class=False, norm=False, + augmnetation=False, data_augmentation=False ) -> None: @@ -30,6 +31,7 @@ class ShapenetDataDgcnn(object): self.return_cls_label = return_cls_label self.just_one_class = just_one_class self.norm = norm + self.augmnetation = augmnetation self.data_augmentation = data_augmentation # data operations @@ -143,6 +145,24 @@ class ShapenetDataDgcnn(object): m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) pc = pc / m return pc + + def jitter_pointcloud(self, pointcloud, sigma=0.01, clip=0.02): + N, C = pointcloud.shape + pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) + return pointcloud + + def rotate_pointcloud(self, pointcloud): + theta = np.pi*2 * np.random.uniform() + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) + pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) # random rotation (x,z) + return pointcloud + + def translate_pointcloud(self, pointcloud): + xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) + xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) + + translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') + return translated_pointcloud # TODO: add the selection for a given class @@ -164,6 +184,13 @@ class ShapenetDataDgcnn(object): if self.norm: point_set = self.normalize(point_set) + if self.augmnetation: + point_set = self.jitter_pointcloud(point_set) + point_set = self.rotate_pointcloud(point_set) + point_set = self.translate_pointcloud(point_set) + + + choice = np.random.choice(len(point_set), self.npoints, replace=True) point_set = point_set[choice, :] -- GitLab