diff --git a/PyG_implementation/my_data_loader.py b/PyG_implementation/my_data_loader.py index 37ce63dac57d8bf591e8cabaaf19a33234eb197a..2de8ed2f630098d31e7b39a12899676fe3835d86 100644 --- a/PyG_implementation/my_data_loader.py +++ b/PyG_implementation/my_data_loader.py @@ -11,7 +11,7 @@ from torch_geometric.data import ( Data, InMemoryDataset ) -from torch_geometric.io import read_txt_array +from torch_geometric.io import parse_txt_array class MyData(InMemoryDataset): @@ -82,11 +82,28 @@ class MyData(InMemoryDataset): data_list = self.process_filenames(split_filenames) torch.save(self.collate(data_list), self.processed_paths[i]) + def read_txt_array(self, path, sep=',', start=0, end=None, dtype=None, device=None): + with open(path, 'r') as f: + src = f.read().split('\n') + + # Check if the first line is a header + if any(c.isalpha() for c in src[0]): + # If it's a header, remove it from src + src = src[1:] + + src = src[:-1] # Exclude the last line, as in your original code + + # # Split each line into a list of numbers + # src = [line.split(sep) for line in src] + + return parse_txt_array(src, sep, start, end, dtype, device) + + def process_filenames(self, filenames): data_list = [] for name in filenames: - data = read_txt_array(osp.join(self.raw_dir, name)) + data = self.read_txt_array(osp.join(self.raw_dir, name)) pos = data[:, :3] x = data[:, 3:6] y = data[:, self.label_location].type(torch.long)