diff --git a/dgcnn/dgcnn_train_pl.py b/dgcnn/dgcnn_train_pl.py index 67b04b9dfa781cdc97d87e50528183f47040a169..65b4fb5d1dcc18fb945c8edc81a8a019d44dd4a2 100644 --- a/dgcnn/dgcnn_train_pl.py +++ b/dgcnn/dgcnn_train_pl.py @@ -140,7 +140,8 @@ shapenet_data_train = ShapenetDataDgcnn( small_data_size=config['data']['small_data_size'], just_one_class=config['data']['just_one_class'], split='train', - norm=config['data']['norm'] + norm=config['data']['norm'], + augmnetation=True ) # get val data diff --git a/dgcnn/req.txt b/dgcnn/req.txt new file mode 100644 index 0000000000000000000000000000000000000000..efd1de65cdc9a1d743c4f492f8d20252386a51d9 --- /dev/null +++ b/dgcnn/req.txt @@ -0,0 +1,3 @@ +torch +wandb +pytorch-lightning==1.9.3 \ No newline at end of file diff --git a/dgcnn/shapenet_data_dgcnn.py b/dgcnn/shapenet_data_dgcnn.py index 501caffc874c0e009fd840816e09473ca5a8f8e3..81966484801040ddc0e9c3a77379187f5ef3e880 100644 --- a/dgcnn/shapenet_data_dgcnn.py +++ b/dgcnn/shapenet_data_dgcnn.py @@ -19,6 +19,7 @@ class ShapenetDataDgcnn(object): return_cls_label=False, just_one_class=False, norm=False, + augmnetation=False, data_augmentation=False ) -> None: @@ -30,6 +31,7 @@ class ShapenetDataDgcnn(object): self.return_cls_label = return_cls_label self.just_one_class = just_one_class self.norm = norm + self.augmnetation = augmnetation self.data_augmentation = data_augmentation # data operations @@ -143,6 +145,24 @@ class ShapenetDataDgcnn(object): m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) pc = pc / m return pc + + def jitter_pointcloud(self, pointcloud, sigma=0.01, clip=0.02): + N, C = pointcloud.shape + pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) + return pointcloud + + def rotate_pointcloud(self, pointcloud): + theta = np.pi*2 * np.random.uniform() + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) + pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) # random rotation (x,z) + return pointcloud + + def translate_pointcloud(self, pointcloud): + xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) + xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) + + translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') + return translated_pointcloud # TODO: add the selection for a given class @@ -164,6 +184,13 @@ class ShapenetDataDgcnn(object): if self.norm: point_set = self.normalize(point_set) + if self.augmnetation: + point_set = self.jitter_pointcloud(point_set) + point_set = self.rotate_pointcloud(point_set) + point_set = self.translate_pointcloud(point_set) + + + choice = np.random.choice(len(point_set), self.npoints, replace=True) point_set = point_set[choice, :]