Maciej Wielgosz / point-transformer / Commits

Commit 3e965fac, authored 2 years ago by Maciej Wielgosz
Parent: 2406277f

    transformer cifar works for a single image
Changes: 1 changed file

cifar_example/cifar_example_transformer.py  (+136 additions, -27 deletions)
# simple cifar10 example using resnet
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

...
@@ -20,10 +21,11 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# get the train loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)

# get the test loader
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()

# train the model
def train(model, device, train_loader, optimizer, epoch):
...
@@ -32,7 +34,7 @@ def train(model, device, train_loader, optimizer, epoch):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
...
@@ -82,38 +84,135 @@ class MyModel(nn.Module):
#### define my transformer model
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
class Embedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
        super(Embedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device
        self.return_patches = return_patches
        self.classify = extra_token
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)
        self.extra_token = None  # initialize extra_token tensor
    def _get_patches(self, x):
        # get the patches of cifar10 images
        # x: (batch_size, 3, 32, 32)
        # patches: (batch_size, 3, 8, 8, 16)
        # note: no return statement here; superseded by get_patches below
        patches = x.unfold(2, 8, 8).unfold(3, 8, 8)
    def get_patches(self, x, patch_size=8):
        # get the patches
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches
    def get_pos_encoding(self, d_emb, max_len):
        # standard sinusoidal encoding: frequencies 10000^(-i/d_emb),
        # with the sine and cosine halves concatenated along the feature axis
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        i = torch.arange(0, d_emb, 2).float()
        div = torch.exp(-i * math.log(10000) / d_emb)
        sin = torch.sin(pos * div)
        cos = torch.cos(pos * div)
        pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)
        return pos_encoding
    def forward(self, x):
        # get the patches
        patches = self.get_patches(x, patch_size=self.patch_size)
        # flatten the patches
        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # add extra embedding token if classification is needed
        if self.classify:
            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
            self.extra_token = self.extra_token.to(x.device)  # move extra_token to the same device as x
            patches = torch.cat((self.extra_token, patches), dim=0)
        # get the embedding
        embedding = self.conv(patches)
        # flatten the embedding
        embedding = embedding.reshape(-1, self.out_channels)
        # normalize the embedding
        embedding = self.norm(embedding)
        # add the positional encoding
        pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
        pos_encoding = pos_encoding.to(x.device)
        embedding = embedding + pos_encoding
        if self.return_patches:
            return embedding, patches
        else:
            return embedding
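A minimal sketch (not part of this commit) of what the Embedding module above produces for one CIFAR-10-sized input, assuming the class as reconstructed here; the shapes follow from the 8x8 patching, the extra classification token, and the broadcast against the (1, max_len, d_emb) positional encoding:

    # illustrative only: exercise Embedding on a single random 32x32 RGB image
    emb = Embedding(patch_size=8, in_channels=3, out_channels=8,
                    return_patches=True, extra_token=True)
    x = torch.rand(1, 3, 32, 32)
    embedding, patches = emb(x)
    print(patches.shape)    # torch.Size([17, 3, 8, 8]): 16 patches plus 1 extra token
    print(embedding.shape)  # torch.Size([1, 17, 8]) after adding the positional encoding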
# define transformer class
class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # Query, Key, Value weight matrices
        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)
        # Final output weight matrix
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        # Create queries, keys, and values
        qkv = self.qkv_linear(x)
        q, k, v = torch.split(qkv, embed_dim, dim=-1)
        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        # Apply attention to values
        weighted_values = torch.matmul(attn, v)
        # Apply final output weight matrix
        output = self.output_linear(weighted_values)
        return output
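SelfAttention is a single-head scaled dot-product attention with a fused QKV projection; it maps a (batch, seq_len, embed_dim) tensor to a tensor of the same shape. A small sketch (not part of this commit, shapes chosen only for illustration):

    # illustrative only: shape check for the single-head self-attention block
    attn = SelfAttention(embed_dim=64)
    tokens = torch.rand(2, 17, 64)   # (batch, seq_len, embed_dim)
    out = attn(tokens)
    print(out.shape)                 # torch.Size([2, 17, 64])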
from torch.nn import TransformerEncoderLayer


class PthBasedTransformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        # note: this early return short-circuits forward(); PthBasedTransformer defines
        # no _get_patches, so the classification path below is unreachable as written
        return self._get_patches(x)
        embedding, patches = self.embedding(x)
        context = self.self_attention(embedding)
        # get the first token
        context = context[:, 0, :]
        # get the classification
        context = self.fc(context)
        return context
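In line with the commit message ("transformer cifar works for a single image") and the batch_size=1 loaders, a hedged sketch (not part of this commit) of a single-image pass through the pieces of PthBasedTransformer, bypassing the early return noted above:

    # illustrative only: run one image through the model's components by hand
    model = PthBasedTransformer()
    x = torch.rand(1, 3, 32, 32)                 # one CIFAR-10 image
    embedding, patches = model.embedding(x)      # embedding: torch.Size([1, 17, 8])
    context = model.self_attention(embedding)    # encoder layer keeps the same shape
    logits = model.fc(context[:, 0, :])          # classify from the extra token
    print(logits.shape)                          # torch.Size([1, 10])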
class MyTransformer(nn.Module):
    def __init__(self):
        super(MyTransformer, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=64)
        self.attention = SelfAttention(embed_dim=64)
        self.fc1 = nn.Linear(64, 10)  # note: redefines fc1, overriding the 16*5*5 -> 120 layer above

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.embedding(x)
        x = self.attention(x)
        x = self.fc1(x)
        return x
...
@@ -123,6 +222,7 @@ def main(train_model=False, model_type="resnet"):
    # use cuda if available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # device = torch.device("cpu")

    if model_type == "resnet":
...
@@ -131,6 +231,12 @@ def main(train_model=False, model_type="resnet"):
    elif model_type == "cnn":
        # get the cnn model
        model = MyModel().to(device)
    elif model_type == "transformer":
        # get the transformer model
        model = MyTransformer().to(device)
    elif model_type == "pth_transformer":
        # get the transformer model
        model = PthBasedTransformer().to(device)

    if not train_model:
        # check if the model exists
...
@@ -153,8 +259,11 @@ def main(train_model=False, model_type="resnet"):
    torch.save(model.state_dict(), "cifar_cnn.pt")


if __name__ == '__main__':
    train_model = False
    train_model = True
    model_type = "cnn"
    # model_type = "resnet"
    model_type = "pth_transformer"

    main(train_model=train_model,
...