Commit 3e965fac authored by Maciej Wielgosz

transformer cifar works for a single image

parent 2406277f
# simple cifar10 example using resnet
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -20,10 +21,11 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
# get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+criterion = torch.nn.CrossEntropyLoss()
# train the model
def train(model, device, train_loader, optimizer, epoch):
@@ -32,7 +34,7 @@ def train(model, device, train_loader, optimizer, epoch):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
-        loss = F.cross_entropy(output, target)
+        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
@@ -82,38 +84,135 @@ class MyModel(nn.Module):
#### define my transformer model
-class Embeddings(nn.Module):
-    def __init__(self, d_model, vocab):
-        super(Embeddings, self).__init__()
-        self.lut = nn.Embedding(vocab, d_model)
-        self.d_model = d_model
+class Embedding(nn.Module):
+    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+        super(Embedding, self).__init__()
+        self.patch_size = patch_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.device = device
+        self.return_patches = return_patches
+        self.classify = extra_token
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
+        self.norm = nn.LayerNorm(out_channels)
+        self.extra_token = None  # initialize extra_token tensor
-    def _get_patches(self, x):
-        # get the patches of cifar10 images
-        # x: (batch_size, 3, 32, 32)
-        # patches: (batch_size, 3, 8, 8, 16)
-        patches = x.unfold(2, 8, 8).unfold(3, 8, 8)
+    def get_patches(self, x, patch_size=8):
+        # get the patches
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches
+    def get_pos_encoding(self, d_emb, max_len):
+        pos = torch.arange(0, max_len).float().unsqueeze(1)
+        i = torch.arange(0, d_emb, 2).float()
+        div = torch.exp(-i * math.log(10000) / d_emb)
+        sin = torch.sin(pos * div)
+        cos = torch.cos(pos * div)
+        pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)
+        return pos_encoding
+    def forward(self, x):
+        # get the patches
+        patches = self.get_patches(x, patch_size=self.patch_size)
+        # flatten the patches
+        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
+        # add extra embedding token if classification is needed
+        if self.classify:
+            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
+            self.extra_token = self.extra_token.to(x.device)  # move extra_token to the same device as x
+            patches = torch.cat((self.extra_token, patches), dim=0)
+        # get the embedding
+        embedding = self.conv(patches)
+        # flatten the embedding
+        embedding = embedding.reshape(-1, self.out_channels)
+        # normalize the embedding
+        embedding = self.norm(embedding)
+        # add the positional encoding
+        pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
+        pos_encoding = pos_encoding.to(x.device)
+        embedding = embedding + pos_encoding
+        if self.return_patches:
+            return embedding, patches
+        else:
+            return embedding
+# define transformer class
+class SelfAttention(nn.Module):
+    def __init__(self, embed_dim):
+        super().__init__()
+        # Query, Key, Value weight matrices
+        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)
+        # Final output weight matrix
+        self.output_linear = nn.Linear(embed_dim, embed_dim)
+    def forward(self, x):
+        batch_size, seq_len, embed_dim = x.size()
+        # Create queries, keys, and values
+        qkv = self.qkv_linear(x)
+        q, k, v = torch.split(qkv, embed_dim, dim=-1)
+        # Compute attention scores
+        scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5)
+        attn = torch.softmax(scores, dim=-1)
+        # Apply attention to values
+        weighted_values = torch.matmul(attn, v)
+        # Apply final output weight matrix
+        output = self.output_linear(weighted_values)
+        return output
+from torch.nn import TransformerEncoderLayer
+class PthBasedTransformer(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
+        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
+        self.fc = nn.Linear(8, 10)
    def forward(self, x):
-        return self._get_patches(x)
+        embedding, patches = self.embedding(x)
+        context = self.self_attention(embedding)
+        # get the first token
+        context = context[:, 0, :]
+        # get the classification
+        context = self.fc(context)
+        return context
class MyTransformer(nn.Module):
    def __init__(self):
        super(MyTransformer, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=64)
+        self.attention = SelfAttention(embed_dim=64)
+        self.fc1 = nn.Linear(64, 10)
    def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
+        x = self.embedding(x)
+        x = self.attention(x)
+        x = self.fc1(x)
        return x
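For orientation, here is a minimal shape check of the new Embedding block on a single CIFAR-10 image. This is an illustrative sketch, not part of the commit, and it assumes the classes defined above are in scope:

import torch

# build the embedding exactly as PthBasedTransformer does
emb = Embedding(patch_size=8, in_channels=3, out_channels=8,
                return_patches=True, extra_token=True)
x = torch.rand(1, 3, 32, 32)   # one CIFAR-10 image
tokens, patches = emb(x)
print(patches.shape)   # torch.Size([17, 3, 8, 8]): 16 patches of 8x8 plus the random extra token
print(tokens.shape)    # torch.Size([1, 17, 8]): normalized patch embeddings with positional encoding added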
@@ -123,6 +222,7 @@ def main(train_model=False, model_type="resnet"):
    # use cuda if available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
+    # device = torch.device("cpu")
    if model_type == "resnet":
@@ -131,6 +231,12 @@ def main(train_model=False, model_type="resnet"):
    elif model_type == "cnn":
        # get the cnn model
        model = MyModel().to(device)
+    elif model_type == "transformer":
+        # get the transformer model
+        model = MyTransformer().to(device)
+    elif model_type == "pth_transformer":
+        # get the transformer model
+        model = PthBasedTransformer().to(device)
    if not train_model:
        # check if model exists
@@ -153,8 +259,11 @@ def main(train_model=False, model_type="resnet"):
        torch.save(model.state_dict(), "cifar_cnn.pt")
if __name__ == '__main__':
-    train_model = False
+    train_model = True
    model_type = "cnn"
    # model_type = "resnet"
+    model_type = "pth_transformer"
    main(
        train_model=train_model,
......
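To mirror the commit message ("transformer cifar works for a single image"), a single-image smoke test of the pth_transformer path could look as follows. This sketch is not part of the committed file and assumes the definitions above are importable:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10

# one test image at a time, matching the new batch_size=1 loaders
test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

model = PthBasedTransformer()
model.eval()
image, label = next(iter(loader))   # image: (1, 3, 32, 32)
with torch.no_grad():
    logits = model(image)           # (1, 10) class scores taken from the extra token
print(logits.argmax(dim=1).item(), label.item())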