Skip to content
Snippets Groups Projects
Commit 86333368 authored by Maciej Wielgosz's avatar Maciej Wielgosz
Browse files

updated the loader

parent c32b0aaf
No related branches found
No related tags found
No related merge requests found
......@@ -130,5 +130,7 @@ data/
shapenet_part_seg_hdf5_data/
ShapeNet/
maciek_data
nibio_data
nibio_data_no_commas
```
......@@ -19,12 +19,14 @@ class MyData(InMemoryDataset):
def __init__(self,
             root: str,
             label_location: int = -1,
             include_normals: bool = True,
             split: str = 'trainval',
             split_ratio: tuple = (0.7, 0.15, 0.15),
             transform: Optional[Callable] = None,
             pre_transform: Optional[Callable] = None,
             pre_filter: Optional[Callable] = None):
    """In-memory point-cloud dataset rooted at *root*.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        label_location: Column index of the per-point label in the raw
            text files; the default ``-1`` selects the last column.
        include_normals: Accepted but not stored in this hunk —
            presumably consumed elsewhere in the class; TODO confirm.
        split: Requested split name (e.g. ``'trainval'``).
        split_ratio: ``(train, val, test)`` fractions; stored for later use.
        transform: Standard PyG per-access transform hook.
        pre_transform: Standard PyG preprocessing hook.
        pre_filter: Standard PyG sample-filter hook.
    """
    self.label_location = label_location
    self.split_ratio = split_ratio
    # NOTE(review): `split` and `include_normals` are not assigned to self
    # here; assumed to be handled further down in __init__, outside this
    # diff hunk — confirm against the full file.
    super().__init__(root, transform, pre_transform, pre_filter)
......@@ -87,7 +89,7 @@ class MyData(InMemoryDataset):
data = read_txt_array(osp.join(self.raw_dir, name))
pos = data[:, :3]
x = data[:, 3:6]
y = data[:, -1].type(torch.long)
y = data[:, self.label_location].type(torch.long)
category = torch.tensor(0, dtype=torch.long) # there is only one category !!
data = Data(pos=pos, x=x, y=y, category=category)
if self.pre_filter is not None and not self.pre_filter(data):
......
......@@ -11,11 +11,12 @@ from torch_scatter import scatter
from torchmetrics.functional import jaccard_index
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv
from my_data_loader import MyData
wandb_project = "pyg-point-cloud" #@param {"type": "string"} , maciej-wielgosz-nibio
wandb_run_name = "train-dgcnn" #@param {"type": "string"}
......@@ -35,13 +36,12 @@ random.seed(config.seed)
torch.manual_seed(config.seed)
device = torch.device(config.device)
config.category = 'Car' #@param ["Bag", "Cap", "Car", "Chair", "Earphone", "Guitar", "Knife", "Lamp", "Laptop", "Motorbike", "Mug", "Pistol", "Rocket", "Skateboard", "Table"] {type:"raw"}
config.random_jitter_translation = 1e-2
config.random_rotation_interval_x = 15
config.random_rotation_interval_y = 15
config.random_rotation_interval_z = 15
config.validation_split = 0.2
config.batch_size = 4
config.batch_size = 1
config.num_workers = 6
config.num_nearest_neighbours = 30
......@@ -62,23 +62,10 @@ transform = T.Compose([
])
pre_transform = T.NormalizeScale()
# dataset_path = os.path.join('ShapeNet', config.category)
# train_val_dataset = ShapeNet(
# dataset_path, config.category, split='trainval',
# transform=transform, pre_transform=pre_transform
# )
from my_data_loader import MyData
dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/maciek_data/plane_maciek"
# dataset_path = "/home/nibio/mutable-outside-world/code/nibio_graph_sem_seg/nibio_data_no_commas"
train_val_dataset = MyData(
dataset_path, config.category, split='trainval',
transform=transform, pre_transform=pre_transform
)
train_val_dataset = MyData(dataset_path, split='trainval', transform=transform, pre_transform=pre_transform)
segmentation_class_frequency = {}
......@@ -111,7 +98,6 @@ visualization_loader = DataLoader(
)
class DGCNN(torch.nn.Module):
def __init__(self, out_channels, k=30, aggr='max'):
super().__init__()
......@@ -197,9 +183,6 @@ def train_step(epoch):
iou = torch.tensor(ious, device=device)
category = torch.cat(categories, dim=0)
print("iou shape:", iou.shape)
print("category shape:", category.shape)
mean_iou = float(scatter(iou, category, reduce='mean').mean())
......@@ -236,10 +219,8 @@ def val_step(epoch):
total_nodes += data.num_nodes
sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
data.category.tolist()):
category = list(ShapeNet.seg_classes.keys())[category]
part = ShapeNet.seg_classes[category]
for out, y in zip(outs.split(sizes), data.y.split(sizes)):
part = MyData.seg_classes
part = torch.tensor(part, device=device)
y_map[part] = torch.arange(part.size(0), device=device)
......@@ -278,12 +259,10 @@ def visualization_step(epoch, table):
y_map = torch.empty(
visualization_loader.dataset.num_classes, device=device
).long()
for out, y, category in zip(
outs.split(sizes), data.y.split(sizes), data.category.tolist()
):
category = list(ShapeNet.seg_classes.keys())[category]
part = ShapeNet.seg_classes[category]
for out, y in zip(outs.split(sizes), data.y.split(sizes)):
part = MyData.seg_classes
part = torch.tensor(part, device=device)
y_map[part] = torch.arange(part.size(0), device=device)
iou = jaccard_index(
out[:, part].argmax(dim=-1), y_map[y],
......
import os
import argparse
from tqdm import tqdm
def remove_commas(source_folder, target_folder):
    """Copy every ``.txt`` file from *source_folder* into *target_folder*,
    replacing each comma with a single space.

    Note: despite the function name, commas are replaced with spaces — not
    deleted — so the values remain separated for whitespace-based parsers.

    Args:
        source_folder: Directory containing the input ``.txt`` files.
        target_folder: Output directory; created if it does not exist.
    """
    os.makedirs(target_folder, exist_ok=True)

    for filename in os.listdir(source_folder):
        # Guard clause: only .txt files are converted; everything else is skipped.
        if not filename.endswith(".txt"):
            continue
        source_path = os.path.join(source_folder, filename)
        target_path = os.path.join(target_folder, filename)
        with open(source_path, 'r') as source_file, \
                open(target_path, 'w') as target_file:
            # tqdm reports per-line progress; the line count is unknown up
            # front, so no total is given.
            for line in tqdm(source_file):
                # Replace commas with spaces (keeps columns separated).
                target_file.write(line.replace(',', ' '))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Replace commas with spaces in all .txt files of a folder."
    )
    # required=True: previously a missing flag left the value as None, and
    # os.listdir(None) would silently scan the current working directory.
    parser.add_argument("--source_folder", required=True,
                        help="Folder with files that contain commas")
    parser.add_argument("--target_folder", required=True,
                        help="Folder with files that don't contain commas")
    args = parser.parse_args()

    remove_commas(args.source_folder, args.target_folder)
    print(f"Files without commas are saved in : {args.target_folder}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment