Commit 2406277f authored by Maciej Wielgosz

progress towards own model (cifar experiments)

parent 9a22434c
%% Cell type:code id: tags:

``` python
import torch
# import the cifar10 dataset and torchvision transforms
from torchvision.datasets import CIFAR10
from torchvision import transforms

# set a seed for reproducibility
torch.manual_seed(0)

# define an augmenting transform (defined here, but not applied below)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
])

# get the train data; plain ToTensor keeps pixel values in [0, 1] for plotting
train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
# to apply the augmenting transform instead:
# train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)

# get the first image and show it
img, label = train_data[0]
import matplotlib.pyplot as plt
plt.imshow(img.permute(1, 2, 0))
plt.show()
```
%% Output

Files already downloaded and verified
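%% Cell type:markdown id: tags:

A quick aside on the unused `transform` above: the random flips accept both PIL images and tensors, so the pipeline is valid as written, but plotting a normalized image looks washed out. As a minimal sketch (the reordering and the `aug` / `aug_img` names are mine, not part of the original cell), this is how the augmenting pipeline could be applied and un-normalized for display:

%% Cell type:code id: tags:

``` python
# sketch: flips on the PIL image, then ToTensor and Normalize; afterwards
# undo Normalize((0.5, ...), (0.5, ...)) so imshow gets values in [0, 1]
aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
aug_data = CIFAR10(root='./data', train=True, download=True, transform=aug)
aug_img, _ = aug_data[0]
plt.imshow((aug_img * 0.5 + 0.5).permute(1, 2, 0))
plt.show()
```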
%% Cell type:code id: tags:

``` python
patch_size = 4

def get_patches(x, patch_size=8):
    # split (B, C, H, W) into a (B, C, H/p, W/p, p, p) patch grid
    patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return patches

# run the function on the first image
patches = get_patches(img.unsqueeze(0), patch_size=patch_size)

# show the patches in a grid derived from patches.shape
no = int(32 / patch_size)
print('no: ', no)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(patches[0, :, i, j, :].permute(1, 2, 0))
        ax[i, j].axis('off')
print('patches 10: ', patches[0, :, 0, 0, :].permute(1, 2, 0))
```
%% Output

no:  8
patches 10:  tensor([[[0.2314, 0.2431, 0.2471],
         [0.1686, 0.1804, 0.1765],
         [0.1961, 0.1882, 0.1686],
         [0.2667, 0.2118, 0.1647]],

        [[0.0627, 0.0784, 0.0784],
         [0.0000, 0.0000, 0.0000],
         [0.0706, 0.0314, 0.0000],
         [0.2000, 0.1059, 0.0314]],

        [[0.0980, 0.0941, 0.0824],
         [0.0627, 0.0275, 0.0000],
         [0.1922, 0.1059, 0.0314],
         [0.3255, 0.1961, 0.0902]],

        [[0.1294, 0.0980, 0.0667],
         [0.1490, 0.0784, 0.0157],
         [0.3412, 0.2118, 0.0980],
         [0.4157, 0.2471, 0.1098]]])
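%% Cell type:markdown id: tags:

The unfold output keeps the patch grid as separate dimensions, `(B, C, H/p, W/p, p, p)`. A transformer wants a flat token sequence instead. A minimal sketch of that flattening (the `x`, `p`, and `tokens` names are mine):

%% Cell type:code id: tags:

``` python
# sketch: flatten the patch grid into a ViT-style token sequence
# (B, C, H/p, W/p, p, p) -> (B, N, C * p * p) with N = (H/p) * (W/p)
x = img.unsqueeze(0)                                 # (1, 3, 32, 32)
p = patch_size                                       # 4, from the cell above
tokens = get_patches(x, patch_size=p)                # (1, 3, 8, 8, 4, 4)
tokens = tokens.permute(0, 2, 3, 1, 4, 5)            # (1, 8, 8, 3, 4, 4)
tokens = tokens.reshape(x.shape[0], -1, 3 * p * p)   # (1, 64, 48)
print('tokens: ', tokens.shape)
```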
%% Cell type:code id: tags:

``` python
from torch import nn

# define embedding class
class Embedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels, return_patches=False):
        super(Embedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.return_patches = return_patches
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)

    def get_patches(self, x, patch_size=8):
        # get the patches
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches

    def forward(self, x):
        print('patch size: ', self.patch_size)
        # get the patches and flatten them to (B * N, C, p, p)
        patches = self.get_patches(x, patch_size=self.patch_size)
        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # embed each patch with a strided conv, flatten to (B * N, D), normalize
        embedding = self.conv(patches)
        embedding = embedding.reshape(-1, self.out_channels)
        embedding = self.norm(embedding)
        if self.return_patches:
            return embedding, patches
        else:
            return embedding

# use the embedding class
patch_size = 16
embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True)
embedding(img.unsqueeze(0))
# show the embedding and patches; note that this second call shadows the
# module with its output tensor, hence 'patch size' printing twice below
embedding, patches = embedding(img.unsqueeze(0))
print('patches: ', patches.shape)
print('embedding: ', embedding.shape)

# plot the patches
no = int(32 / patch_size)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(patches[i * no + j, :].permute(1, 2, 0))
        ax[i, j].axis('off')

# plot the embeddings
no = int(32 / patch_size)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))
        ax[i, j].axis('off')
```
%% Output

patch size:  16
patch size:  16
patches:  torch.Size([4, 3, 16, 16])
embedding:  torch.Size([4, 8])
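%% Cell type:markdown id: tags:

Note that `Embedding` folds the batch into the patch dimension and returns `(B * N, D)`. A common alternative, shown here as a sketch (the `BatchedEmbedding` name and shape convention are mine, not the class above), runs the strided conv once over the whole image and keeps a separate batch dimension, `(B, N, D)`:

%% Cell type:code id: tags:

``` python
# sketch: batched patch embedding via one strided conv over the full image
class BatchedEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        emb = self.conv(x)                    # (B, D, H/p, W/p)
        emb = emb.flatten(2).transpose(1, 2)  # (B, N, D)
        return self.norm(emb)

batched = BatchedEmbedding(patch_size=16, in_channels=3, out_channels=8)
print(batched(img.unsqueeze(0)).shape)  # torch.Size([1, 4, 8])
```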
%% Cell type:code id: tags:

``` python
# create position embedding
class PositionEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels):
        super(PositionEmbedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)

    def get_patches(self, x, patch_size=8):
        # get the patches
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches

    # define cosinusoidal position embedding
    def get_cosine_position_embedding(self, x, patch_size=8):
        # get the patches
        patches = self.get_patches(x, patch_size=patch_size)
        print('patches shape inside : ', patches.shape)
        # get the batch size
        batch_size = patches.shape[0]
        # get the number of patches
        # NOTE: this division yields a float, which breaks the reshape below
        no_patches = (32 / patch_size) ** 2
        # get the position embedding
        pos_embedding = torch.arange(0, no_patches).unsqueeze(1) / (10000 ** (torch.arange(0, self.out_channels, 2) / self.out_channels))
        print('self.out_channels inside : ', self.out_channels)
        print('pos_embedding inside : ', pos_embedding.shape)
        # get the sine and cosine embedding
        pos_embedding = torch.cat([torch.sin(pos_embedding), torch.cos(pos_embedding)], dim=1).reshape(1, no_patches, self.out_channels).permute(0, 2, 1)
        # expand the position embedding over the batch
        pos_embedding = pos_embedding.expand(batch_size, -1, -1)
        print('pos_embedding inside before return : ', pos_embedding.shape)
        return pos_embedding

    def forward(self, x):
        # get the patches and flatten them
        patches = self.get_patches(x, patch_size=self.patch_size)
        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # get the embedding, flatten and normalize it
        embedding = self.conv(patches)
        embedding = embedding.reshape(-1, self.out_channels)
        embedding = self.norm(embedding)
        # add the position embedding
        pos_embedding = self.get_cosine_position_embedding(x, patch_size=self.patch_size)
        print('pos embedding: ', pos_embedding.shape)
        return embedding + pos_embedding

# use the position embedding
patch_size = 16
pos_embedding = PositionEmbedding(patch_size=patch_size, in_channels=3, out_channels=8)
pos_embedding(img.unsqueeze(0))

# show the embedding and patches
embedding, patches = pos_embedding(img.unsqueeze(0))
print('patches: ', patches.shape)
print('embedding: ', embedding.shape)

# # plot the patches
# no = int(32 / patch_size)
# fig, ax = plt.subplots(no, no)
# for i in range(no):
#     for j in range(no):
#         ax[i, j].imshow(patches[i * no + j, :].permute(1, 2, 0))
#         ax[i, j].axis('off')
# # plot the embeddings
# no = int(32 / patch_size)
# fig, ax = plt.subplots(no, no)
# for i in range(no):
#     for j in range(no):
#         ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))
#         ax[i, j].axis('off')

# plot results of get_cosine_position_embedding
pos_embedding = pos_embedding.get_cosine_position_embedding(img.unsqueeze(0), patch_size=patch_size)
print('pos_embedding: ', pos_embedding.shape)
no = int(32 / patch_size)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(pos_embedding[0, :, i * no + j].detach().numpy().reshape(1, -1))
        ax[i, j].axis('off')
```

%% Output

patches shape inside :  torch.Size([1, 3, 2, 2, 16, 16])
self.out_channels inside :  8
pos_embedding inside :  torch.Size([4, 4])

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 68
     65 patch_size = 16
     66 pos_embedding = PositionEmbedding(patch_size=patch_size, in_channels=3, out_channels=8)
---> 68 pos_embedding(img.unsqueeze(0))
     70 # show the embedding and patches
     71 embedding, patches = pos_embedding(img.unsqueeze(0))

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[6], line 58, in PositionEmbedding.forward(self, x)
     56 embedding = self.norm(embedding)
     57 # add the position embedding
---> 58 pos_embedding = self.get_cosine_position_embedding(x, patch_size=self.patch_size)
     60 print('pos embedding: ', pos_embedding.shape)
     61 return embedding + pos_embedding

Cell In[6], line 38, in PositionEmbedding.get_cosine_position_embedding(self, x, patch_size)
     35 print('pos_embedding inside : ', pos_embedding.shape)
     37 # get the sine and cosine embedding
---> 38 pos_embedding = torch.cat([torch.sin(pos_embedding), torch.cos(pos_embedding)], dim=1).reshape(1, no_patches, self.out_channels).permute(0, 2, 1)
     39 # expand the position embedding
     40 pos_embedding = pos_embedding.expand(batch_size, -1, -1)

TypeError: reshape(): argument 'shape' must be tuple of SymInts, but found element of type float at pos 2

%% Cell type:code id: tags:

``` python
import math
import torch

def sinusoidal_encoding_table(n_position, d_hid, padding_idx=None):
    '''Generate a sinusoidal position encoding table'''
    encoding_table = torch.zeros(n_position, d_hid)
    position = torch.arange(0, n_position).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_hid, 2) * -(math.log(10000.0) / d_hid))
    # interleave: sine in the even columns, cosine in the odd columns
    encoding_table[:, 0::2] = torch.sin(position * div_term)
    encoding_table[:, 1::2] = torch.cos(position * div_term)
    if padding_idx is not None:
        encoding_table[padding_idx] = 0.
    return encoding_table

seq_len = 100
embedding_dim = 64
pos_encoding = sinusoidal_encoding_table(seq_len, embedding_dim)
print(pos_encoding)

# plot the position encoding
fig, ax = plt.subplots(1, 1)
ax.imshow(pos_encoding.detach().numpy())
ax.axis('off')
```

%% Output

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
          1.3335e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
          2.6670e-04,  1.0000e+00],
        ...,
        [ 3.7961e-01, -9.2515e-01, -4.6453e-01,  ...,  9.9985e-01,
          1.2935e-02,  9.9992e-01],
        [-5.7338e-01, -8.1929e-01, -9.4349e-01,  ...,  9.9985e-01,
          1.3068e-02,  9.9991e-01],
        [-9.9921e-01,  3.9821e-02, -9.1628e-01,  ...,  9.9985e-01,
          1.3201e-02,  9.9991e-01]])

(-0.5, 63.5, 99.5, -0.5)

%% Cell type:code id: tags:

``` python
def get_pos_encoding(max_len, d_emb):
    pos = torch.arange(0, max_len).float().unsqueeze(1)
    i = torch.arange(0, d_emb, 2).float()
    div = torch.exp(-i * math.log(10000) / d_emb)

    sin = torch.sin(pos * div)
    cos = torch.cos(pos * div)
    # concatenate rather than interleave: first half sine, second half cosine
    pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)
    return pos_encoding

seq_len = 100
embedding_dim = 64
pos_encoding = get_pos_encoding(seq_len, embedding_dim)
print(pos_encoding)

# plot the position encoding
fig, ax = plt.subplots(1, 1)
ax.imshow(pos_encoding.squeeze().detach().numpy())
ax.axis('off')
```

%% Output

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.8415,  0.6816,  0.5332,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.9093,  0.9975,  0.9021,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [ 0.3796, -0.4645, -0.9086,  ...,  0.9997,  0.9999,  0.9999],
         [-0.5734, -0.9435, -0.9914,  ...,  0.9997,  0.9998,  0.9999],
         [-0.9992, -0.9163, -0.7687,  ...,  0.9997,  0.9998,  0.9999]]])

(-0.5, 63.5, 99.5, -0.5)
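%% Cell type:markdown id: tags:

As a sanity check that the pieces compose, here is a minimal sketch (mine, not from the commit) of adding the sinusoidal table to the per-patch embeddings from the `Embedding` class; using `32 // patch_size` keeps the patch count an integer and avoids the float `reshape` error that crashed `PositionEmbedding` above:

%% Cell type:code id: tags:

``` python
# sketch: patch embeddings + sinusoidal position encodings
patch_size = 16
d_model = 8
num_patches = (32 // patch_size) ** 2                      # 4, as an int
emb = Embedding(patch_size, 3, d_model)(img.unsqueeze(0))  # (4, 8), one row per patch
pos = sinusoidal_encoding_table(num_patches, d_model)      # (4, 8)
tokens = emb + pos
print('tokens with positions: ', tokens.shape)             # torch.Size([4, 8])
```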