"""ShapeNetPart part-segmentation data utilities.

Contains a downloader for the ShapeNetPart HDF5 archive, a loader that
concatenates the HDF5 files of a partition, a helper that renders the part
colour legend image, and a ``torch`` ``Dataset`` wrapper around the data.
"""
import glob
import json
import os
import shutil
import subprocess

import h5py
import cv2
import numpy as np
from torch.utils.data import Dataset

# Name of the directory (inside ``data/``) that holds the extracted HDF5 files.
_DATASET_DIRNAME = 'shapenet_part_seg_hdf5_data'
_DATASET_URL = 'https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip'


def _data_dir():
    """Return the path of the ``data`` directory located next to this module."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')


def download_shapenetpart():
    """Download and unpack the ShapeNetPart HDF5 archive if it is missing.

    The archive is fetched with ``wget`` and extracted with ``unzip`` in the
    current working directory; the extracted ``hdf5_data`` directory is then
    moved to ``<module dir>/data/shapenet_part_seg_hdf5_data`` and the zip
    file is deleted. Nothing happens when the target directory already exists.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits with a
            non-zero status.
        OSError: if either tool is not installed, or the move / cleanup fails.
    """
    data_dir = _data_dir()
    os.makedirs(data_dir, exist_ok=True)

    target_dir = os.path.join(data_dir, _DATASET_DIRNAME)
    if os.path.exists(target_dir):
        return

    archive = os.path.basename(_DATASET_URL)
    # Argument lists (no shell) avoid quoting problems, and check=True makes a
    # failed download visible here instead of as an empty dataset later on.
    # NOTE(review): certificate verification is disabled, as in the original
    # command line -- confirm this is still required for the host.
    subprocess.run(['wget', '--no-check-certificate', _DATASET_URL], check=True)
    subprocess.run(['unzip', archive], check=True)
    shutil.move('hdf5_data', target_dir)
    os.remove(archive)


def load_data_partseg(partition):
    """Load every HDF5 file of ``partition`` and concatenate the contents.

    Args:
        partition: substring used to select files (``*<partition>*.h5``).
            The special value ``'trainval'`` selects both the ``*train*.h5``
            and the ``*val*.h5`` files.

    Returns:
        A tuple ``(data, label, seg)`` of arrays concatenated along axis 0,
        read from the ``'data'``, ``'label'`` and ``'pid'`` datasets and cast
        to ``float32``, ``int64`` and ``int64`` respectively.

    Raises:
        ValueError: if no HDF5 file matches the partition.
    """
    download_shapenetpart()
    dataset_dir = os.path.join(_data_dir(), _DATASET_DIRNAME)

    if partition == 'trainval':
        patterns = ['*train*.h5', '*val*.h5']
    else:
        patterns = ['*%s*.h5' % partition]

    h5_paths = []
    for pattern in patterns:
        h5_paths.extend(glob.glob(os.path.join(dataset_dir, pattern)))
    if not h5_paths:
        # np.concatenate would raise an opaque ValueError on an empty list;
        # fail with the same exception type but a message that names the cause.
        raise ValueError(
            'No HDF5 files found for partition %r in %s' % (partition, dataset_dir))

    all_data = []
    all_label = []
    all_seg = []
    for h5_path in h5_paths:
        # Read-only mode is sufficient and also works on write-protected files.
        with h5py.File(h5_path, 'r') as h5_file:
            all_data.append(h5_file['data'][:].astype('float32'))
            all_label.append(h5_file['label'][:].astype('int64'))
            all_seg.append(h5_file['pid'][:].astype('int64'))

    return (np.concatenate(all_data, axis=0),
            np.concatenate(all_label, axis=0),
            np.concatenate(all_seg, axis=0))


def load_color_partseg():
    """Load the part colours and render a colour legend image.

    Reads ``prepare_data/meta/partseg_colors.txt`` (JSON content: a sequence
    of objects with ``'color'`` and ``'label'`` entries; the path is relative
    to the current working directory) and, as a side effect, writes a legend
    with one colour swatch and label per part to
    ``prepare_data/meta/partseg_colors.png`` on every call.

    Returns:
        ``np.ndarray`` holding the ``'color'`` entries in file order, with the
        channel order exactly as stored in the file.

    Raises:
        IndexError: if the file holds fewer entries than the legend layout
            needs (the sum of ``column_numbers``, i.e. 50).
    """
    with open("prepare_data/meta/partseg_colors.txt") as colors_file:
        entries = json.load(colors_file)
    colors = [entry['color'] for entry in entries]
    labels = [entry['label'] for entry in entries]

    # Channels are reversed for drawing only (presumably RGB -> BGR, the
    # channel order OpenCV draws with); the returned array is not reordered.
    draw_colors = np.array(colors)[:, [2, 1, 0]]

    font = cv2.FONT_HERSHEY_SIMPLEX
    img = np.zeros((1350, 1890, 3), dtype="uint8")
    # White background; the rectangle is larger than the canvas on purpose.
    cv2.rectangle(img, (0, 0), (1900, 1900), [255, 255, 255], thickness=-1)

    # Swatches per legend row (16 rows, 50 swatches in total).
    column_numbers = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
    # Horizontal advance after the swatch in column i.
    column_gaps = [320, 320, 300, 300, 285, 285]
    color_size = 64

    entry_index = 0
    top = 16
    for swatches_in_row in column_numbers:
        left = 32
        for column in range(swatches_in_row):
            blue, green, red = (int(c) for c in draw_colors[entry_index])
            cv2.rectangle(img, (left, top),
                          (left + color_size, top + color_size),
                          color=(blue, green, red), thickness=-1)
            img = cv2.putText(img, str(labels[entry_index]),
                              (left + int(color_size * 1.15),
                               top + int(color_size / 2)),
                              font, 0.76, (0, 0, 0), 2)
            left += column_gaps[column]
            entry_index += 1
        top += int(color_size * 1.3)

    cv2.imwrite("prepare_data/meta/partseg_colors.png", img,
                [cv2.IMWRITE_PNG_COMPRESSION, 0])
    return np.array(colors)


class ShapeNetPart(Dataset):
    """ShapeNetPart dataset yielding ``(pointcloud, label, seg)`` samples.

    Args:
        num_points: number of leading points kept from every sample.
        partition: partition name forwarded to ``load_data_partseg``.
        class_choice: optional category name (a key of ``cat2id``); when
            given, only samples of that category are kept.

    Raises:
        KeyError: if ``class_choice`` is not a known category name.

    Note:
        Construction calls ``load_color_partseg``, which reads the colour
        table and rewrites the legend PNG.
    """

    def __init__(self, num_points, partition='train', class_choice=None):
        super().__init__()
        self.data, self.label, self.seg = load_data_partseg(partition)
        # Category name -> integer category id as stored in the labels.
        self.cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4,
                       'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9,
                       'motor': 10, 'mug': 11, 'pistol': 12, 'rocket': 13,
                       'skateboard': 14, 'table': 15}
        # Number of part classes per category and the index of each
        # category's first part class (running sum of seg_num).
        self.seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
        self.index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
        self.num_points = num_points
        self.partition = partition
        self.class_choice = class_choice
        self.partseg_colors = load_color_partseg()

        if self.class_choice is not None:
            id_choice = self.cat2id[self.class_choice]
            # reshape(-1) instead of squeeze(): squeeze() yields a 0-d mask
            # when the dataset holds exactly one sample, which would add an
            # axis on indexing instead of filtering.
            indices = (self.label == id_choice).reshape(-1)
            self.data = self.data[indices]
            self.label = self.label[indices]
            self.seg = self.seg[indices]
            self.seg_num_all = self.seg_num[id_choice]
            self.seg_start_index = self.index_start[id_choice]
        else:
            self.seg_num_all = 50
            self.seg_start_index = 0

    def __getitem__(self, item):
        """Return ``(pointcloud, label, seg)`` for sample ``item``.

        The point cloud and its per-point labels are truncated to the first
        ``num_points`` entries. For the ``'trainval'`` partition the points
        (and their labels, consistently) are put in a random order.
        """
        pointcloud = self.data[item][:self.num_points]
        label = self.label[item]
        seg = self.seg[item][:self.num_points]
        if self.partition == 'trainval':
            # Only the point order is randomised; no other augmentation.
            indices = list(range(pointcloud.shape[0]))
            np.random.shuffle(indices)
            pointcloud = pointcloud[indices]
            seg = seg[indices]
        return pointcloud, label, seg

    def __len__(self):
        """Return the number of samples."""
        return self.data.shape[0]