Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
point-transformer
Manage
Activity
Members
Plan
Wiki
Code
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Model registry
Analyze
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Maciej Wielgosz
point-transformer
Commits
04bde564
Commit
04bde564
authored
2 years ago
by
Maciej Wielgosz
Browse files
Options
Downloads
Patches
Plain Diff
Parallel implementation of DGCNN in PyTorch Lightning
parent
36f46314
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
dgcnn/dgcnn_train_pl.py
+63
-0
63 additions, 0 deletions
dgcnn/dgcnn_train_pl.py
dgcnn/shapenet_data_dgcnn.py
+0
-2
0 additions, 2 deletions
dgcnn/shapenet_data_dgcnn.py
with
63 additions
and
2 deletions
dgcnn/dgcnn_train_pl.py
0 → 100644
+
63
−
0
View file @
04bde564
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.init
as
init
from
shapenet_data_dgcnn
import
ShapenetDataDgcnn
import
pytorch_lightning
as
pl
from
pytorch_lightning.loggers
import
WandbLogger
from
model
import
DGCNN
class DGCNNLightning(pl.LightningModule):
    """LightningModule wrapper around a DGCNN point-cloud classifier.

    Delegates the forward pass to a ``DGCNN`` backbone and trains it
    with cross-entropy on batches unpacked as
    ``(points, <unused>, class_name)``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Backbone network producing per-sample class logits.
        self.dgcnn = DGCNN(num_classes)

    def forward(self, x):
        logits = self.dgcnn(x)
        return logits

    def training_step(self, batch, batch_idx):
        # Batch layout: (points, <ignored middle element>, class labels).
        points, _, class_name = batch
        logits = self(points)
        # ignore_index=255 skips any target equal to 255; for whole-object
        # class labels this is not expected to trigger.
        loss = F.cross_entropy(
            logits,
            class_name,
            reduction='mean',
            ignore_index=255,
        )
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Plain Adam over every parameter of the module.
        return torch.optim.Adam(self.parameters(), lr=0.0001)
# get data
# ShapenetDataDgcnn yields tuples unpacked as (points, _, class_name) in
# DGCNNLightning.training_step; 256 points are sampled per shape.
shapenet_data = ShapenetDataDgcnn(
    root='/home/nibio/mutable-outside-world/code/oracle_gpu_runs/data/shapenet',
    npoints=256,
    return_cls_label=True,
    small_data=False,
    small_data_size=1000,
    just_one_class=False,
    split='train',
    norm=True,
)

# create a dataloader
# NOTE(review): num_workers=8 at module level needs an
# `if __name__ == "__main__":` guard on spawn-based platforms
# (Windows/macOS) — confirm the intended runtime before relying on it.
dataloader = torch.utils.data.DataLoader(
    shapenet_data,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

# Initialize a trainer
wandb_logger = WandbLogger(
    project="dgcnn",
    name="dgcnn",
    entity="maciej-wielgosz-nibio",
)

# BUGFIX: the original passed both `devices=[0]` and the deprecated `gpus=1`.
# pytorch_lightning rejects that combination with a MisconfigurationException
# (and `gpus` was removed entirely in Lightning 2.0). `devices=[0]` alone
# selects the first device of the auto-detected accelerator.
trainer = pl.Trainer(
    accelerator="auto",
    devices=[0],
    max_epochs=3,
    logger=wandb_logger,
)

# Initialize a model
model = DGCNNLightning(num_classes=16)
wandb_logger.watch(model)  # stream gradients/parameters to W&B

# Train the model on gpu
trainer.fit(model, dataloader)
This diff is collapsed.
Click to expand it.
dgcnn/shapenet_data_dgcnn.py
+
0
−
2
View file @
04bde564
...
...
@@ -165,8 +165,6 @@ class ShapenetDataDgcnn(object):
point_set
=
self
.
normalize
(
point_set
)
choice
=
np
.
random
.
choice
(
len
(
point_set
),
self
.
npoints
,
replace
=
True
)
# choose the first npoints
choice
=
np
.
arange
(
self
.
npoints
)
point_set
=
point_set
[
choice
,
:]
point_set
=
point_set
.
astype
(
np
.
float32
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment