From 41a20aa6f86eeda24b96dd0ff3a324aac999b78c Mon Sep 17 00:00:00 2001 From: Maciej Wielgosz <maciej.wielgosz@nibio.no> Date: Tue, 6 Jun 2023 12:30:20 +0200 Subject: [PATCH] updated flow for splitting and merging --- PyG_implementation/my_data_loader.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/PyG_implementation/my_data_loader.py b/PyG_implementation/my_data_loader.py index 37ce63d..2de8ed2 100644 --- a/PyG_implementation/my_data_loader.py +++ b/PyG_implementation/my_data_loader.py @@ -11,7 +11,7 @@ from torch_geometric.data import ( Data, InMemoryDataset ) -from torch_geometric.io import read_txt_array +from torch_geometric.io import parse_txt_array class MyData(InMemoryDataset): @@ -82,11 +82,28 @@ class MyData(InMemoryDataset): data_list = self.process_filenames(split_filenames) torch.save(self.collate(data_list), self.processed_paths[i]) + def read_txt_array(self, path, sep=',', start=0, end=None, dtype=None, device=None): + with open(path, 'r') as f: + src = f.read().split('\n') + + # Check if the first line is a header + if any(c.isalpha() for c in src[0]): + # If it's a header, remove it from src + src = src[1:] + + src = src[:-1] # Exclude the last line, as in your original code + + # # Split each line into a list of numbers + # src = [line.split(sep) for line in src] + + return parse_txt_array(src, sep, start, end, dtype, device) + + def process_filenames(self, filenames): data_list = [] for name in filenames: - data = read_txt_array(osp.join(self.raw_dir, name)) + data = self.read_txt_array(osp.join(self.raw_dir, name)) pos = data[:, :3] x = data[:, 3:6] y = data[:, self.label_location].type(torch.long) -- GitLab