Maciej Wielgosz / point-transformer / Commits

Commit 4febaf36
Authored 2 years ago by Maciej Wielgosz
transformer for cifar10 works - problems with class embedding
Parent: 3e965fac
Showing 1 changed file: cifar_example/cifar_example_transformer.py (+63, −25)
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torchvision
 from torchvision import datasets, transforms
+import wandb

 # import resnet18 from trochvision
 from torchvision.models import resnet18
@@ -21,10 +23,10 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
 test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

 # get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
 # get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

 criterion = torch.nn.CrossEntropyLoss()

 # train the model
@@ -42,6 +44,19 @@ def train(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
+            # log to wandb
+            wandb.log({"loss": loss.item()})
+            wandb.log({"epoch": epoch})
+
+            # compute the accuracy
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            correct = pred.eq(target.view_as(pred)).sum().item()
+            accuracy = correct / len(data)
+
+            # log to wandb
+            wandb.log({"accuracy": accuracy})

 # test the model
 def test(model, device, test_loader):
     model.eval()
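Note: by default each separate wandb.log() call advances wandb's step counter, so the loss, epoch, and accuracy logged above can land on different steps. The sketch below (an illustration, not repository code; the dummy tensors stand in for the hunk's output, target, and loss) logs the same metrics in a single call so they share one step:

import torch
import wandb

wandb.init(mode="disabled")                    # disabled mode lets the sketch run without a W&B account

# dummy tensors stand in for the hunk's `output`, `target`, and `loss`
output = torch.randn(32, 10)                   # logits for a batch of 32 CIFAR-10 images
target = torch.randint(0, 10, (32,))           # ground-truth labels
loss = torch.nn.functional.cross_entropy(output, target)

pred = output.argmax(dim=1, keepdim=True)      # index of the max logit
correct = pred.eq(target.view_as(pred)).sum().item()

wandb.log({                                    # one call, so all three metrics share a step
    "loss": loss.item(),
    "epoch": 0,
    "accuracy": correct / len(output),
})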
@@ -84,22 +99,22 @@ class MyModel(nn.Module):
 #### define my transformer model
+# define embedding class
 class Embedding(nn.Module):
-    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
         super(Embedding, self).__init__()
         self.patch_size = patch_size
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.device = device
         self.return_patches = return_patches
         self.classify = extra_token
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
         self.norm = nn.LayerNorm(out_channels)
+        self.extra_token = None  # initialize extra_token tensor

     def get_patches(self, x, patch_size=8):
         # get the patches
-        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
         return patches
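Note: for a CIFAR-10 batch of shape (B, 3, 32, 32) and patch_size=8, the two unfold calls produce (B, 3, 4, 4, 8, 8). The standalone shape sketch below is an illustration, not repository code; the permute is a suggestion worth checking, since reshape(-1, 3, 8, 8) applied directly to the unfold layout appears to mix values from different patches into one block:

import torch

x = torch.randn(2, 3, 32, 32)                       # stand-in CIFAR-10 batch, B=2
patches = x.unfold(2, 8, 8).unfold(3, 8, 8)         # (2, 3, 4, 4, 8, 8): a 4x4 grid of 8x8 patches
print(patches.shape)

# moving the patch grid ahead of the channel dim keeps each patch's RGB values together
per_patch = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, 3, 8, 8)   # (32, 3, 8, 8)
conv = torch.nn.Conv2d(3, 64, kernel_size=8, stride=8)
print(conv(per_patch).shape)                        # (32, 64, 1, 1): one 64-dim vector per patch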
@@ -121,12 +136,6 @@ class Embedding(nn.Module):
         patches = self.get_patches(x, patch_size=self.patch_size)
         # flatten the patches
         patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
-
-        # add extra embedding token if classification is needed
-        if self.classify:
-            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
-            self.extra_token = self.extra_token.to(x.device)  # move extra_token to the same device as x
-            patches = torch.cat((self.extra_token, patches), dim=0)
         # get the embedding
         embedding = self.conv(patches)
         # flatten the embedding
@@ -135,14 +144,22 @@ class Embedding(nn.Module):
         embedding = self.norm(embedding)

         # add the positional encoding
         pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-        pos_encoding = pos_encoding.to(x.device)
-        embedding = embedding + pos_encoding
+        embedding = embedding + pos_encoding.to(x.device)
+
+        # reshape the embedding
+        embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
+
+        if self.classify:
+            # add the classification token
+            classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
+            embedding = torch.cat((classification_token, embedding), dim=1)

         if self.return_patches:
             return embedding, patches
         else:
             return embedding

 # define transformer class
 class SelfAttention(nn.Module):
     def __init__(self, embed_dim):
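Note: the commit message mentions "problems with class embedding", and the classification token above is redrawn with torch.rand on every forward pass, so it carries no learned information and differs between calls. A common alternative is the ViT-style learnable token registered once as a parameter; the sketch below is an illustration under that assumption, not the repository's code:

import torch
import torch.nn as nn

class ClassTokenEmbedding(nn.Module):
    """Sketch of a ViT-style learnable class token (illustration only)."""

    def __init__(self, out_channels):
        super().__init__()
        # a single token, trained jointly with the rest of the model
        self.cls_token = nn.Parameter(torch.zeros(1, 1, out_channels))

    def add_class_token(self, embedding):
        # embedding: (batch, num_patches, out_channels)
        cls = self.cls_token.expand(embedding.shape[0], -1, -1)   # (batch, 1, out_channels)
        return torch.cat((cls, embedding), dim=1)                 # (batch, 1 + num_patches, out_channels)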
@@ -178,26 +195,28 @@ from torch.nn import TransformerEncoderLayer
 class PthBasedTransformer(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
-        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
-        self.fc = nn.Linear(8, 10)
+        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.self_attention = TransformerEncoderLayer(
+            d_model=embedding_size, nhead=16, dim_feedforward=embedding_size * 4, dropout=0.3)
+        self.fc = nn.Linear(embedding_size, 10)

     def forward(self, x):
         embedding, patches = self.embedding(x)
         context = self.self_attention(embedding)
         # get the first token
         context = context[:, 0, :]
+        # context = context.mean(dim=1)
         # get the classification
         context = self.fc(context)
         return context
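Note: Embedding returns tensors shaped (batch, tokens, embedding_size), while TransformerEncoderLayer defaults to the (tokens, batch, embedding_size) convention, in which case context[:, 0, :] would not select each sample's class token. The sketch below uses batch_first=True (available in recent PyTorch releases) and assumes patch_size=16 and embedding_size=64 as in the hunk; it is a suggestion, not the commit's code:

import torch
from torch.nn import TransformerEncoderLayer

embedding_size = 64
layer = TransformerEncoderLayer(
    d_model=embedding_size, nhead=16, dim_feedforward=embedding_size * 4,
    dropout=0.3, batch_first=True)            # treat dim 0 as the batch dimension

tokens = torch.randn(8, 5, embedding_size)    # (batch, 1 class token + 4 patches of a 32x32 image, embedding)
context = layer(tokens)
cls_out = context[:, 0, :]                    # per-sample class-token output: (8, 64)
print(cls_out.shape)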
@@ -260,10 +279,29 @@ def main(train_model=False, model_type="resnet"):
 if __name__ == '__main__':
     train_model = True
-    model_type = "cnn"
+    # model_type = "cnn"
     # model_type = "resnet"
     model_type = "pth_transformer"
+
+    # Create a config object for wandb
+    config = {
+        'model_type': model_type,
+        'batch_size': 64,
+        'test_batch_size': 1000,
+        'epochs': 10,
+        'lr': 0.01,
+        'momentum': 0.5,
+        'no_cuda': False,
+        'seed': 1,
+        'log_interval': 10,
+        'patch_size': 8,
+        'in_channels': 3,
+        'out_channels': 64,
+        'extra_token': True
+    }
+
+    # add model type to wandb
+    wandb.init(project="cifar10_example_transformer", entity="maciej-wielgosz-nibio", config=config)

     main(
         train_model=train_model,
 ...
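Note: the config logged to wandb records patch_size 8 and out_channels 64, while the model hunk hard-codes patch_size=16 and embedding_size=64. The sketch below (an illustration that assumes the config dict, wandb.init call, and PthBasedTransformer defined above in this file) reads the model hyperparameters back from wandb.config so the logged values and the model stay in sync:

# sketch, assuming the definitions above in this file (config, PthBasedTransformer)
wandb.init(project="cifar10_example_transformer",
           entity="maciej-wielgosz-nibio", config=config)
model = PthBasedTransformer(embedding_size=wandb.config.out_channels)
# a patch_size argument would still need to be threaded through
# PthBasedTransformer and Embedding for wandb.config.patch_size to take effect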