Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
point-transformer
Manage
Activity
Members
Plan
Wiki
Code
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Model registry
Analyze
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Maciej Wielgosz
point-transformer
Commits
0441820e
Commit
0441820e
authored
2 years ago
by
Maciej Wielgosz
Browse files
Options
Downloads
Patches
Plain Diff
update of pl forest transformer
parent
8338af5b
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
forest_sem_seg_transformer_pl.py
+143
-12
143 additions, 12 deletions
forest_sem_seg_transformer_pl.py
with
143 additions
and
12 deletions
forest_sem_seg_transformer_pl.py
+
143
−
12
View file @
0441820e
import
argparse
import
importlib
import
os
import
shutil
import
hydra
import
numpy
as
np
import
omegaconf
...
...
@@ -45,6 +47,18 @@ class ForestSemSegTransformer(pl.LightningModule):
for
label
in
self
.
seg_classes
[
cat
]:
self
.
seg_label_to_cat
[
label
]
=
cat
self
.
results_dir
=
hydra
.
utils
.
to_absolute_path
(
'
results
'
)
# create folder to save the results las files
if
not
os
.
path
.
exists
(
self
.
results_dir
):
os
.
mkdir
(
self
.
results_dir
)
self
.
test_dataset
=
Dataset
(
root
=
hydra
.
utils
.
to_absolute_path
(
'
data/forest_txt/validation_txt/
'
),
npoints
=
self
.
conf
.
num_point
,
normal_channel
=
self
.
conf
.
normal
,
normalize_point_cloud
=
False
)
def
forward
(
self
,
data
):
return
self
.
model
(
data
)
...
...
@@ -72,12 +86,79 @@ class ForestSemSegTransformer(pl.LightningModule):
else
:
optimizer
=
torch
.
optim
.
SGD
(
self
.
model
.
parameters
(),
lr
=
self
.
conf
.
learning_rate
,
momentum
=
0.9
)
# update learning rate
lr
=
max
(
self
.
conf
.
learning_rate
*
(
self
.
conf
.
lr_decay
**
(
self
.
current_epoch
//
self
.
conf
.
step_size
)),
self
.
learning_rate_clip
)
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'
lr
'
]
=
lr
return
optimizer
def
random_scale_point_cloud_pth
(
self
,
batch_data
,
scale_low
=
0.8
,
scale_high
=
1.25
):
"""
Randomly scale the point cloud. Scale is per point cloud.
Input:
batch_data: (B, N, 3) tensor, original batch of point clouds
scale_low: float, lower bound of the random scale factor
scale_high: float, upper bound of the random scale factor
Return:
(B, N, 3) tensor, scaled batch of point clouds
"""
B
,
N
,
C
=
batch_data
.
size
()
scales
=
torch
.
FloatTensor
(
B
).
uniform_
(
scale_low
,
scale_high
)
scales
=
scales
.
view
(
B
,
1
,
1
)
# Reshape for broadcasting
batch_data
*=
scales
.
to
(
batch_data
.
device
)
return
batch_data
def
shift_point_cloud_pth
(
self
,
batch_data
,
shift_range
=
0.1
):
"""
Randomly shift point cloud. Shift is per point cloud.
Input:
batch_data: (B, N, 3) tensor, original batch of point clouds
shift_range: float, maximum distance to shift each point
Return:
(B, N, 3) tensor, shifted batch of point clouds
"""
B
,
N
,
C
=
batch_data
.
size
()
shifts
=
torch
.
FloatTensor
(
B
,
1
,
3
).
uniform_
(
-
shift_range
,
shift_range
)
shifts
=
shifts
.
to
(
batch_data
.
device
)
batch_data
+=
shifts
return
batch_data
def
shift_point_cloud_np
(
self
,
batch_data
,
shift_range
=
0.1
):
"""
Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
# convert to numpy array
batch_data
=
batch_data
.
cpu
().
numpy
()
B
,
N
,
C
=
batch_data
.
shape
shifts
=
np
.
random
.
uniform
(
-
shift_range
,
shift_range
,
(
B
,
3
))
for
batch_index
in
range
(
B
):
batch_data
[
batch_index
,:,:]
+=
shifts
[
batch_index
,:]
# convert back to torch tensor
batch_data
=
torch
.
from_numpy
(
batch_data
).
cuda
()
return
batch_data
def
random_scale_point_cloud_np
(
self
,
batch_data
,
scale_low
=
0.8
,
scale_high
=
1.25
):
"""
Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
# convert to numpy array
batch_data
=
batch_data
.
cpu
().
numpy
()
B
,
N
,
C
=
batch_data
.
shape
scales
=
np
.
random
.
uniform
(
scale_low
,
scale_high
,
B
)
for
batch_index
in
range
(
B
):
batch_data
[
batch_index
,:,:]
*=
scales
[
batch_index
]
# convert back to torch tensor
batch_data
=
torch
.
from_numpy
(
batch_data
).
cuda
()
return
batch_data
# define epoch se
def
on_train_epoch_start
(
self
)
->
None
:
...
...
@@ -89,17 +170,25 @@ class ForestSemSegTransformer(pl.LightningModule):
self
.
model
=
self
.
model
.
train
()
self
.
acc_train_mean_correct
=
[]
# update learning rate
lr
=
max
(
self
.
conf
.
learning_rate
*
(
self
.
conf
.
lr_decay
**
(
self
.
current_epoch
//
self
.
conf
.
step_size
)),
self
.
learning_rate_clip
)
for
param_group
in
self
.
optimizers
().
param_groups
:
param_group
[
'
lr
'
]
=
lr
def
on_training_epoch_end
(
self
):
self
.
log
(
"
train_acc
"
,
np
.
mean
(
self
.
acc_train_mean_correct
),
on_step
=
True
,
on_epoch
=
True
,
prog_bar
=
True
,
logger
=
True
)
self
.
log
(
"
epoch
"
,
self
.
current_epoch
,
on_step
=
True
,
on_epoch
=
True
,
prog_bar
=
True
,
logger
=
True
)
def
training_step
(
self
,
batch
,
batch_idx
):
points
,
label
=
batch
points
[:,
:,
0
:
3
]
=
provider
.
random_scale_point_cloud
(
points
[:,
:,
0
:
3
])
points
[:,
:,
0
:
3
]
=
provider
.
shift_point_cloud
(
points
[:,
:,
0
:
3
])
# points[:, :, 0:3] = self.random_scale_point_cloud_np(points[:, :, 0:3])
points
[:,
:,
0
:
3
]
=
self
.
random_scale_point_cloud_pth
(
points
[:,
:,
0
:
3
])
# points[:, :, 0:3] = self.shift_point_cloud_np(points[:, :, 0:3])
points
[:,
:,
0
:
3
]
=
self
.
shift_point_cloud_pth
(
points
[:,
:,
0
:
3
])
points
=
torch
.
Tensor
(
points
)
points
,
label
=
points
.
float
().
cuda
(),
label
.
long
().
cuda
()
...
...
@@ -107,6 +196,7 @@ class ForestSemSegTransformer(pl.LightningModule):
[
points
,
self
.
to_categorical
(
torch
.
ones
((
points
.
shape
[
0
],
1
),
dtype
=
torch
.
float16
).
cuda
(),
16
).
repeat
(
1
,
points
.
shape
[
1
],
1
)],
-
1
))
seg_pred
=
seg_pred
.
contiguous
().
view
(
-
1
,
self
.
conf
.
num_part
)
target
=
label
.
view
(
-
1
,
1
)[:,
0
]
pred_choice
=
seg_pred
.
data
.
max
(
1
)[
1
]
...
...
@@ -172,6 +262,31 @@ class ForestSemSegTransformer(pl.LightningModule):
logits
=
cur_pred_val_logits
[
i
,
:,
:]
cur_pred_val
[
i
,
:]
=
np
.
argmax
(
logits
[:,
self
.
seg_classes
[
cat
]],
1
)
+
self
.
seg_classes
[
cat
][
0
]
# get x,y,z coordinates of points
points
=
points
.
cpu
().
data
.
numpy
()
points
=
points
[:,
:,
0
:
3
]
points_pd
=
np
.
concatenate
([
points
,
np
.
expand_dims
(
cur_pred_val
,
axis
=
2
)],
axis
=
2
)
points_gt
=
np
.
concatenate
([
points
,
np
.
expand_dims
(
target
,
axis
=
2
)],
axis
=
2
)
# save points as text files in the results folder and preserve the same name as the original txt file
for
i
in
range
(
cur_batch_size
):
np
.
savetxt
(
os
.
path
.
join
(
self
.
results_dir
,
self
.
test_dataset
.
datapath
[
batch_idx
*
self
.
conf
.
batch_size
+
i
].
split
(
'
/
'
)[
-
1
].
replace
(
'
.txt
'
,
'
_pred.txt
'
)
),
points_pd
[
i
,
:,
:],
fmt
=
'
%f %f %f %d
'
)
# copy original txt file to the results folder
shutil
.
copy
(
self
.
test_dataset
.
datapath
[
batch_idx
*
self
.
conf
.
batch_size
+
i
],
os
.
path
.
join
(
self
.
results_dir
,
self
.
test_dataset
.
datapath
[
batch_idx
*
self
.
conf
.
batch_size
+
i
].
split
(
'
/
'
)[
-
1
]
))
# save ground truth labels as text files in the results folder and preserve the same name as the original txt file
np
.
savetxt
(
os
.
path
.
join
(
self
.
results_dir
,
self
.
test_dataset
.
datapath
[
batch_idx
*
self
.
conf
.
batch_size
+
i
].
split
(
'
/
'
)[
-
1
].
replace
(
'
.txt
'
,
'
_gt.txt
'
)
),
points_gt
[
i
,
:,
:],
fmt
=
'
%f %f %f %d
'
)
correct
=
np
.
sum
(
cur_pred_val
==
target
)
self
.
total_correct
+=
correct
self
.
total_seen
+=
(
cur_batch_size
*
NUM_POINT
)
...
...
@@ -196,8 +311,6 @@ class ForestSemSegTransformer(pl.LightningModule):
return
self
.
test_metrics
class
ForestDataset
(
pl
.
LightningDataModule
):
def
__init__
(
self
,
conf
):
super
().
__init__
()
...
...
@@ -208,8 +321,18 @@ class ForestDataset(pl.LightningDataModule):
self
.
train_dataset
=
None
def
setup
(
self
,
stage
=
None
):
self
.
test_dataset
=
Dataset
(
root
=
self
.
test_dataset_path
,
npoints
=
self
.
conf
.
num_point
,
normal_channel
=
self
.
conf
.
normal
,
normalize_point_cloud
=
True
)
self
.
train_dataset
=
Dataset
(
root
=
self
.
train_dataset_path
,
npoints
=
self
.
conf
.
num_point
,
normal_channel
=
self
.
conf
.
normal
,
normalize_point_cloud
=
True
)
self
.
test_dataset
=
Dataset
(
root
=
self
.
test_dataset_path
,
npoints
=
self
.
conf
.
num_point
,
normal_channel
=
self
.
conf
.
normal
,
normalize_point_cloud
=
self
.
conf
.
normalize_point_cloud
)
self
.
train_dataset
=
Dataset
(
root
=
self
.
train_dataset_path
,
npoints
=
self
.
conf
.
num_point
,
normal_channel
=
self
.
conf
.
normal
,
normalize_point_cloud
=
True
)
def
train_dataloader
(
self
):
return
torch
.
utils
.
data
.
DataLoader
(
...
...
@@ -232,10 +355,18 @@ class ForestDataset(pl.LightningDataModule):
def
main
(
args
):
omegaconf
.
OmegaConf
.
set_struct
(
args
,
False
)
# add a parameter to args
trainer
=
pl
.
Trainer
(
gpus
=
1
,
max_epochs
=
args
.
epoch
)
# args.normalize_point_cloud = True
# trainer.fit(ForestSemSegTransformer(args), ForestDataset(args))
trainer
.
validate
(
ForestSemSegTransformer
(
args
),
ForestDataset
(
args
))
args
.
normalize_point_cloud
=
True
trainer
.
validate
(
ForestSemSegTransformer
(
args
),
ForestDataset
(
args
),
ckpt_path
=
'
/home/nibio/mutable-outside-world/code/oracle_gpu_runs/log/partseg/Hengshuang/lightning_logs/version_71/checkpoints/epoch=9-step=30.ckpt
'
)
if
__name__
==
'
__main__
'
:
main
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment