Commit 4febaf36 authored by Maciej Wielgosz

transformer for cifar10 works - problems with class embedding

parent 3e965fac
@@ -8,6 +8,8 @@ import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
+import wandb
# import resnet18 from torchvision
from torchvision.models import resnet18
@@ -21,10 +23,10 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
# get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
# train the model
@@ -42,6 +44,19 @@ def train(model, device, train_loader, optimizer, epoch):
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
+# log to wandb
+wandb.log({"loss": loss.item()})
+wandb.log({"epoch": epoch})
+# compute the accuracy
+pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+correct = pred.eq(target.view_as(pred)).sum().item()
+accuracy = correct / len(data)
+# log to wandb
+wandb.log({"accuracy": accuracy})
# test the model
def test(model, device, test_loader):
model.eval()
@@ -84,22 +99,22 @@ class MyModel(nn.Module):
#### define my transformer model
# define embedding class
class Embedding(nn.Module):
-def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
super(Embedding, self).__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
-self.device = device
self.return_patches = return_patches
self.classify = extra_token
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(out_channels)
+self.extra_token = None # initialize extra_token tensor
def get_patches(self, x, patch_size=8):
# get the patches
-patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
return patches
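For reference (not part of the commit): `unfold` lays the patches out as (batch, channels, n_h, n_w, patch, patch), so flattening that straight into (-1, in_channels, patch_size, patch_size), as the forward pass below does, walks the channel axis before the patch grid. If each patch is meant to keep its three RGB channels together, a permute before the reshape is the usual way to do it; a quick shape check with 8x8 patches on 32x32 images:

# Shape check (sketch, not from this commit): layout of CIFAR-10 patches produced by unfold.
import torch

x = torch.randn(2, 3, 32, 32)                 # dummy CIFAR-10 batch
patches = x.unfold(2, 8, 8).unfold(3, 8, 8)   # -> (batch, channels, n_h, n_w, patch, patch)
print(patches.shape)                          # torch.Size([2, 3, 4, 4, 8, 8])

# Move the channel axis next to the per-patch pixels before flattening:
per_patch = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, 3, 8, 8)
print(per_patch.shape)                        # torch.Size([32, 3, 8, 8])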
@@ -121,12 +136,6 @@ class Embedding(nn.Module):
patches = self.get_patches(x, patch_size=self.patch_size)
# flatten the patches
patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
-# add extra embedding token if classification is needed
-if self.classify:
-self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
-self.extra_token = self.extra_token.to(x.device) # move extra_token to the same device as x
-patches = torch.cat((self.extra_token, patches), dim=0)
# get the embedding
embedding = self.conv(patches)
# flatten the embedding
@@ -135,14 +144,22 @@ class Embedding(nn.Module):
embedding = self.norm(embedding)
# add the positional encoding
pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-pos_encoding = pos_encoding.to(x.device)
-embedding = embedding + pos_encoding
+embedding = embedding + pos_encoding.to(x.device)
# reshape the embedding
embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
+if self.classify:
+# add the classification token
+classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
+embedding = torch.cat((classification_token, embedding), dim=1)
if self.return_patches:
return embedding, patches
else:
return embedding
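One note on the block above, possibly related to the "problems with class embedding" in the commit message: the classification token is drawn from `torch.rand` inside `forward`, so it is not a learned parameter and is re-sampled for every batch. The ViT-style alternative is a single `nn.Parameter` registered in `__init__` and expanded to the batch size, which also moves with the module's device automatically. A minimal sketch; class and attribute names here are illustrative, not from this repository:

# Sketch of a learnable, ViT-style class token (not part of this commit).
import torch
import torch.nn as nn

class ClassTokenPrepend(nn.Module):
    def __init__(self, num_tokens, embed_dim):
        super().__init__()
        # one learnable token, shared across the batch and trained with the rest of the model
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos = nn.Parameter(torch.zeros(1, num_tokens + 1, embed_dim))

    def forward(self, tokens):                      # tokens: (batch, num_tokens, embed_dim)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        return torch.cat((cls, tokens), dim=1) + self.pos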
# define transformer class
class SelfAttention(nn.Module):
def __init__(self, embed_dim):
@@ -178,26 +195,28 @@ from torch.nn import TransformerEncoderLayer
class PthBasedTransformer(nn.Module):
-def __init__(self) -> None:
+def __init__(self, embedding_size=64) -> None:
super().__init__()
-self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
-self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
-self.fc = nn.Linear(8, 10)
+self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+self.self_attention = TransformerEncoderLayer(
+d_model=embedding_size,
+nhead=16,
+dim_feedforward=embedding_size*4,
+dropout=0.3
+)
+self.fc = nn.Linear(embedding_size, 10)
def forward(self, x):
embedding, patches = self.embedding(x)
context = self.self_attention(embedding)
# get the first token
context = context[:, 0, :]
# context = context.mean(dim=1)
# get the classification
context = self.fc(context)
return context
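A PyTorch API note worth double-checking here (not a change in this commit): `TransformerEncoderLayer` defaults to sequence-first input of shape (tokens, batch, d_model), whereas the embedding above is reshaped to (batch, tokens, d_model) and the class token is read with `context[:, 0, :]`. On PyTorch 1.9+ the `batch_first=True` flag makes the layer accept the batch-first layout directly; a standalone shape check under that assumption:

# Standalone shape check (sketch, not from this commit): batch-first input through TransformerEncoderLayer.
import torch
from torch.nn import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=64, nhead=16, dim_feedforward=256,
                                dropout=0.3, batch_first=True)
dummy = torch.randn(32, 5, 64)   # (batch, tokens, d_model): 4 patches of 16x16 + 1 class token
out = layer(dummy)
print(out.shape)                 # torch.Size([32, 5, 64])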
@@ -260,10 +279,29 @@ def main(train_model=False, model_type="resnet"):
if __name__ == '__main__':
train_model = True
model_type = "cnn"
# model_type = "cnn"
# model_type = "resnet"
model_type = "pth_transformer"
+# Create a config object for wandb
+config = {
+'model_type': model_type,
+'batch_size': 64,
+'test_batch_size': 1000,
+'epochs': 10,
+'lr': 0.01,
+'momentum': 0.5,
+'no_cuda': False,
+'seed': 1,
+'log_interval': 10,
+'patch_size': 8,
+'in_channels': 3,
+'out_channels': 64,
+'extra_token': True
+}
+# add model type to wandb
+wandb.init(project="cifar10_example_transformer", entity="maciej-wielgosz-nibio", config=config)
main(
train_model=train_model,
......