Maciej Wielgosz / point-transformer / Commits / 4febaf36

Commit 4febaf36 authored 2 years ago by Maciej Wielgosz
transfomer for cifar10 works - problems with class embedding
parent 3e965fac
Showing 1 changed file: cifar_example/cifar_example_transformer.py (+63 additions, −25 deletions)
@@ -8,6 +8,8 @@ import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import wandb
# import resnet18 from torchvision
from torchvision.models import resnet18
@@ -21,10 +23,10 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

# get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()

# train the model
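Not part of the commit, but a quick sanity check of what the new batch_size=32 loaders yield per iteration (a minimal sketch, assuming the CIFAR-10 setup above):

# sanity-check sketch (not in the commit): inspect one batch from the new train loader
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([32, 3, 32, 32]) for CIFAR-10 with batch_size=32
print(labels.shape)   # torch.Size([32])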
@@ -42,6 +44,19 @@ def train(model, device, train_loader, optimizer, epoch):
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            # log to wandb
            wandb.log({"loss": loss.item()})
            wandb.log({"epoch": epoch})

            # compute the accuracy
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct = pred.eq(target.view_as(pred)).sum().item()
            accuracy = correct / len(data)
            # log to wandb
            wandb.log({"accuracy": accuracy})

# test the model
def test(model, device, test_loader):
    model.eval()
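The accuracy above is computed and logged per batch, so the wandb curve will be noisy. A common alternative (a sketch only, not what this commit does) accumulates correct predictions over the whole epoch:

# hedged sketch (not in the commit): epoch-level training accuracy
correct_total, seen_total = 0, 0
with torch.no_grad():
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        pred = model(data).argmax(dim=1, keepdim=True)
        correct_total += pred.eq(target.view_as(pred)).sum().item()
        seen_total += len(data)
wandb.log({"train_accuracy_epoch": correct_total / seen_total})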
@@ -84,22 +99,22 @@ class MyModel(nn.Module):

#### define my transformer model
# define embedding class
class Embedding(nn.Module):
-    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
        super(Embedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device
        self.return_patches = return_patches
        self.classify = extra_token
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)
        self.extra_token = None  # initialize extra_token tensor

    def get_patches(self, x, patch_size=8):
        # get the patches
-        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
        return patches
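For reference, the shapes that the double unfold produces on a CIFAR-10 batch (a sketch assuming 32x32 inputs and patch_size=8, matching the reshape seen later in forward):

# shape walkthrough for get_patches (assumes 32x32 CIFAR images, patch_size=8)
x = torch.rand(32, 3, 32, 32)        # (batch, channels, height, width)
p = x.unfold(2, 8, 8)                # -> (32, 3, 4, 32, 8): four horizontal strips of height 8
p = p.unfold(3, 8, 8)                # -> (32, 3, 4, 4, 8, 8): a 4x4 grid of 8x8 patches
flat = p.reshape(-1, 3, 8, 8)        # -> (512, 3, 8, 8): 16 patch slots per image, batch*16 in total
# note: reshaping directly like this interleaves the grid and channel axes; the usual ViT-style
# extraction permutes the grid dimensions in front of the channels before flattening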
@@ -121,12 +136,6 @@ class Embedding(nn.Module):
        patches = self.get_patches(x, patch_size=self.patch_size)
        # flatten the patches
        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # add extra embedding token if classification is needed
        if self.classify:
            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
            self.extra_token = self.extra_token.to(x.device)  # move extra_token to the same device as x
            patches = torch.cat((self.extra_token, patches), dim=0)
        # get the embedding
        embedding = self.conv(patches)
        # flatten the embedding
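The commit message flags problems with the class embedding, and this block is one likely spot: the extra token is redrawn with torch.rand on every forward pass, so it is neither learned nor stable between batches. For comparison only (a sketch, not the repository's code), the ViT-style variant keeps the class token as a single learnable parameter created once in __init__:

# hedged sketch: ViT-style learnable class token, for comparison with the torch.rand token above
import torch
from torch import nn

class LearnableClassToken(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # one trainable token, broadcast over the batch at forward time
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, embedding):                       # embedding: (batch, seq, embed_dim)
        cls = self.cls_token.expand(embedding.shape[0], -1, -1)
        return torch.cat((cls, embedding), dim=1)       # (batch, seq + 1, embed_dim)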
@@ -135,14 +144,22 @@ class Embedding(nn.Module):
        embedding = self.norm(embedding)
        # add the positional encoding
        pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-        pos_encoding = pos_encoding.to(x.device)
-        embedding = embedding + pos_encoding
+        embedding = embedding + pos_encoding.to(x.device)
        # reshape the embedding
        embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
        if self.classify:
            # add the classification token
            classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
            embedding = torch.cat((classification_token, embedding), dim=1)
        if self.return_patches:
            return embedding, patches
        else:
            return embedding

# define transformer class
class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
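get_pos_encoding is called above but sits outside this diff. For orientation, a standard sinusoidal positional encoding with the same (d_model, num_positions) argument order looks roughly like this; it is an assumption about the intended behaviour, not the repository's implementation:

# hedged sketch: sinusoidal positional encoding, returns shape (num_positions, d_model)
import math
import torch

def sinusoidal_pos_encoding(d_model, num_positions):
    # assumes d_model is even
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)      # (num_positions, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))                        # (d_model / 2,)
    pe = torch.zeros(num_positions, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe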
@@ -178,26 +195,28 @@ from torch.nn import TransformerEncoderLayer
class PthBasedTransformer(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, embedding_size=64) -> None:
        super().__init__()
-        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
-        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
-        self.fc = nn.Linear(8, 10)
+        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.self_attention = TransformerEncoderLayer(d_model=embedding_size, nhead=16, dim_feedforward=embedding_size*4, dropout=0.3)
+        self.fc = nn.Linear(embedding_size, 10)

    def forward(self, x):
        embedding, patches = self.embedding(x)
        context = self.self_attention(embedding)
        # get the first token
        context = context[:, 0, :]
        # context = context.mean(dim=1)
        # get the classification
        context = self.fc(context)
        return context
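One detail worth flagging when reading forward: TransformerEncoderLayer defaults to sequence-first input of shape (seq, batch, d_model) unless batch_first=True is passed, while the context[:, 0, :] readout treats the output as batch-first. A minimal sketch of the batch-first convention that the indexing assumes, using the new hyperparameters (illustrative only, not the commit's exact construction):

# hedged sketch: batch-first encoder layer with class-token readout
import torch
from torch.nn import TransformerEncoderLayer

embedding_size = 64
layer = TransformerEncoderLayer(d_model=embedding_size, nhead=16,
                                dim_feedforward=embedding_size * 4,
                                dropout=0.3, batch_first=True)
tokens = torch.rand(32, 5, embedding_size)   # (batch, 4 patches + 1 class token, d_model)
out = layer(tokens)                          # same shape: (32, 5, 64)
cls = out[:, 0, :]                           # (32, 64): first token per sample goes to the classifier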
@@ -260,10 +279,29 @@ def main(train_model=False, model_type="resnet"):

if __name__ == '__main__':
    train_model = True
    model_type = "cnn"
    # model_type = "cnn"
    # model_type = "resnet"
    model_type = "pth_transformer"

    # Create a config object for wandb
    config = {
        'model_type': model_type,
        'batch_size': 64,
        'test_batch_size': 1000,
        'epochs': 10,
        'lr': 0.01,
        'momentum': 0.5,
        'no_cuda': False,
        'seed': 1,
        'log_interval': 10,
        'patch_size': 8,
        'in_channels': 3,
        'out_channels': 64,
        'extra_token': True
    }

    # add model type to wandb
    wandb.init(project="cifar10_example_transformer", entity="maciej-wielgosz-nibio", config=config)

    main(train_model=train_model,
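After wandb.init is given this config, the hyperparameters are typically read back through wandb.config instead of hard-coded constants, and the run is closed explicitly when training ends; a brief sketch of that usage (hypothetical, not part of the commit):

# hedged sketch: consuming the wandb config and closing the run
batch_size = wandb.config.batch_size      # 64, as set in the config dict above
lr = wandb.config['lr']                   # dict-style access also works
wandb.log({"epochs_planned": wandb.config.epochs})
wandb.finish()                            # flush and close the wandb run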