diff --git a/.gitignore b/.gitignore index 29dc5f63fdf56545bca96f32d2f97aa75d6ab191..0f2e01bf277d36d2bc642b01704ebc3491e67b53 100644 --- a/.gitignore +++ b/.gitignore @@ -130,5 +130,7 @@ data/ shapenet_part_seg_hdf5_data/ ShapeNet/ maciek_data +nibio_data +nibio_data_no_commas ``` diff --git a/PyG_implementation/my_data_loader.py b/PyG_implementation/my_data_loader.py index cb1c7846ce5e441fe35ff8dd8d951e49eebe11ee..37ce63dac57d8bf591e8cabaaf19a33234eb197a 100644 --- a/PyG_implementation/my_data_loader.py +++ b/PyG_implementation/my_data_loader.py @@ -19,12 +19,14 @@ class MyData(InMemoryDataset): def __init__(self, root: str, + label_location: int = -1, include_normals: bool = True, split: str = 'trainval', split_ratio: tuple = (0.7, 0.15, 0.15), transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None): + self.label_location = label_location self.split_ratio = split_ratio super().__init__(root, transform, pre_transform, pre_filter) @@ -87,7 +89,7 @@ class MyData(InMemoryDataset): data = read_txt_array(osp.join(self.raw_dir, name)) pos = data[:, :3] x = data[:, 3:6] - y = data[:, -1].type(torch.long) + y = data[:, self.label_location].type(torch.long) category = torch.tensor(0, dtype=torch.long) # there is only one category !! 
data = Data(pos=pos, x=x, y=y, category=category) if self.pre_filter is not None and not self.pre_filter(data): diff --git a/PyG_implementation/pyg_implementaion_main_my_data_loader.py b/PyG_implementation/pyg_implementaion_main_my_data_loader.py index 63aed1992df033ba93bc4d5e6952551919c12322..564fdccd513684c84d4057699f29f2e91a1cded1 100644 --- a/PyG_implementation/pyg_implementaion_main_my_data_loader.py +++ b/PyG_implementation/pyg_implementaion_main_my_data_loader.py @@ -11,11 +11,12 @@ from torch_scatter import scatter from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T -from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, DynamicEdgeConv +from my_data_loader import MyData + wandb_project = "pyg-point-cloud" #@param {"type": "string"} , maciej-wielgosz-nibio wandb_run_name = "train-dgcnn" #@param {"type": "string"} @@ -35,13 +36,12 @@ random.seed(config.seed) torch.manual_seed(config.seed) device = torch.device(config.device) -config.category = 'Car' #@param ["Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"} config.random_jitter_translation = 1e-2 config.random_rotation_interval_x = 15 config.random_rotation_interval_y = 15 config.random_rotation_interval_z = 15 config.validation_split = 0.2 -config.batch_size = 4 +config.batch_size = 1 config.num_workers = 6 config.num_nearest_neighbours = 30 @@ -62,23 +62,10 @@ transform = T.Compose([ ]) pre_transform = T.NormalizeScale() - -# dataset_path = os.path.join('ShapeNet', config.category) - - -# train_val_dataset = ShapeNet( -# dataset_path, config.category, split='trainval', -# transform=transform, pre_transform=pre_transform -# ) - -from my_data_loader import MyData - dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/maciek_data/plane_maciek" +# dataset_path = 
"/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/nibio_data_no_commas" -train_val_dataset = MyData( - dataset_path, config.category, split='trainval', - transform=transform, pre_transform=pre_transform -) +train_val_dataset = MyData(dataset_path, split='trainval', transform=transform, pre_transform=pre_transform) segmentation_class_frequency = {} @@ -111,7 +98,6 @@ visualization_loader = DataLoader( ) - class DGCNN(torch.nn.Module): def __init__(self, out_channels, k=30, aggr='max'): super().__init__() @@ -197,9 +183,6 @@ def train_step(epoch): iou = torch.tensor(ious, device=device) category = torch.cat(categories, dim=0) - print("iou shape:", iou.shape) - print("category shape:", category.shape) - mean_iou = float(scatter(iou, category, reduce='mean').mean()) @@ -236,10 +219,8 @@ def val_step(epoch): total_nodes += data.num_nodes sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() - for out, y, category in zip(outs.split(sizes), data.y.split(sizes), - data.category.tolist()): - category = list(ShapeNet.seg_classes.keys())[category] - part = ShapeNet.seg_classes[category] + for out, y in zip(outs.split(sizes), data.y.split(sizes)): + part = MyData.seg_classes part = torch.tensor(part, device=device) y_map[part] = torch.arange(part.size(0), device=device) @@ -278,12 +259,10 @@ def visualization_step(epoch, table): y_map = torch.empty( visualization_loader.dataset.num_classes, device=device ).long() - for out, y, category in zip( - outs.split(sizes), data.y.split(sizes), data.category.tolist() - ): - category = list(ShapeNet.seg_classes.keys())[category] - part = ShapeNet.seg_classes[category] + for out, y in zip(outs.split(sizes), data.y.split(sizes)): + part = MyData.seg_classes part = torch.tensor(part, device=device) + y_map[part] = torch.arange(part.size(0), device=device) iou = jaccard_index( out[:, part].argmax(dim=-1), y_map[y], diff --git a/PyG_implementation/remove_commas.py b/PyG_implementation/remove_commas.py new file mode 100644 index 
0000000000000000000000000000000000000000..0cafefcf250cb0cb21f18327386716e289e1890c --- /dev/null +++ b/PyG_implementation/remove_commas.py @@ -0,0 +1,31 @@ +import os +import argparse +from tqdm import tqdm + +def remove_commas(source_folder, target_folder): + # Create target folder if it doesn't exist + os.makedirs(target_folder, exist_ok=True) + + # Iterate over all txt files in the source folder + for filename in os.listdir(source_folder): + if filename.endswith(".txt"): + # Open the source file and the target file + with open(os.path.join(source_folder, filename), 'r') as source_file, \ + open(os.path.join(target_folder, filename), 'w') as target_file: + # Read each line from the source file + for line in tqdm(source_file): + # Replace each comma with a space + line_without_commas = line.replace(',', ' ') + # Write the new line to the target file + target_file.write(line_without_commas) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--source_folder", help="Folder with files that contain commas") + parser.add_argument("--target_folder", help="Folder with files that don't contain commas") + args = parser.parse_args() + + remove_commas(args.source_folder, args.target_folder) + + print(f"Files without commas are saved in : {args.target_folder}") + \ No newline at end of file