Commit 2406277f authored by Maciej Wielgosz

progress towards own model (cifar experiments)

parent 9a22434c
%% Cell type:code id: tags:
``` python
import torch
# get cifar10 data
# import cifar10 dataset
from torchvision.datasets import CIFAR10
# import torchvision transforms
from torchvision import transforms
# set a seed
torch.manual_seed(0)
# get the train data
# define the transform
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
])
train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
# to apply the full augmentation pipeline instead:
# train_data = CIFAR10(root='./data', train=True, download=True, transform=transform)
# get the first image
img, label = train_data[0]
# show the image
import matplotlib.pyplot as plt
plt.imshow(img.permute(1, 2, 0))
plt.show()
```
%% Output
Files already downloaded and verified
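%% Cell type:markdown id: tags:

Note that `train_data` above is built with a bare `transforms.ToTensor()`, so the augmentation pipeline defined as `transform` is never applied. A minimal sketch of wiring it in (the ordering here is an assumption following the usual torchvision convention of flipping PIL images before `ToTensor`/`Normalize`; the commented-out line in the cell above applies `transform` as originally composed):

%% Cell type:code id: tags:

``` python
# sketch: apply the full augmentation pipeline when loading CIFAR-10
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
augmented_train_data = CIFAR10(root='./data', train=True, download=True,
                               transform=train_transform)
```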
%% Cell type:code id: tags:
``` python
from torch import nn
patch_size = 4
def get_patches(x, patch_size=8):
    # slide non-overlapping windows over H and W: (B, C, H/p, W/p, p, p)
    patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    return patches
# run the function on the first image
patches = get_patches(img.unsqueeze(0), patch_size=patch_size)
# show the patches using patches.shape
no = int(32 / patch_size)
print('no: ', no)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(patches[0, :, i, j, :].permute(1, 2, 0))
        ax[i, j].axis('off')
print('patches 10: ', patches[0, :, 0, 0, :].permute(1, 2, 0))
```
%% Output
no: 8
patches 10:  tensor([[[0.2314, 0.2431, 0.2471],
         [0.1686, 0.1804, 0.1765],
         [0.1961, 0.1882, 0.1686],
         [0.2667, 0.2118, 0.1647]],

        [[0.0627, 0.0784, 0.0784],
         [0.0000, 0.0000, 0.0000],
         [0.0706, 0.0314, 0.0000],
         [0.2000, 0.1059, 0.0314]],

        [[0.0980, 0.0941, 0.0824],
         [0.0627, 0.0275, 0.0000],
         [0.1922, 0.1059, 0.0314],
         [0.3255, 0.1961, 0.0902]],

        [[0.1294, 0.0980, 0.0667],
         [0.1490, 0.0784, 0.0157],
         [0.3412, 0.2118, 0.0980],
         [0.4157, 0.2471, 0.1098]]])
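%% Cell type:markdown id: tags:

`Tensor.unfold(dim, size, step)` slides a window of length `size` along `dim` with stride `step`, adding the window as a new trailing dimension. Applied to dims 2 and 3, it turns a `(B, C, H, W)` image into a `(B, C, H/p, W/p, p, p)` grid of patches. A small sketch to confirm the shapes (the tensor here is synthetic, not CIFAR data):

%% Cell type:code id: tags:

``` python
import torch

x = torch.arange(2 * 3 * 32 * 32, dtype=torch.float32).reshape(2, 3, 32, 32)
p = 4
patches = x.unfold(2, p, p).unfold(3, p, p)
print(patches.shape)  # torch.Size([2, 3, 8, 8, 4, 4])
# patch (i, j) is exactly the corresponding p x p crop of the image
assert torch.equal(patches[0, :, 1, 2], x[0, :, p:2 * p, 2 * p:3 * p])
```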
%% Cell type:code id: tags:
``` python
from torch import nn
# define embedding class
class Embedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels, return_patches=False):
        super(Embedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.return_patches = return_patches
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)

    def get_patches(self, x, patch_size=8):
        # get the patches: (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches

    def forward(self, x):
        # get the patches
        print('patch size: ', self.patch_size)
        patches = self.get_patches(x, patch_size=self.patch_size)
        # flatten the patches (permute first so each row holds one whole patch)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # get the embedding
        embedding = self.conv(patches)
        # flatten the embedding
        embedding = embedding.reshape(-1, self.out_channels)
        # normalize the embedding
        embedding = self.norm(embedding)
        if self.return_patches:
            return embedding, patches
        else:
            return embedding

# use the embedding class
patch_size = 16
embed = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True)
# show the embedding and patches (avoid shadowing the module with its output)
embedding, patches = embed(img.unsqueeze(0))
print('patches: ', patches.shape)
print('embedding: ', embedding.shape)
# plot the patches
no = int(32 / patch_size)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(patches[i * no + j, :].permute(1, 2, 0))
        ax[i, j].axis('off')
# plot the embeddings
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))
        ax[i, j].axis('off')
```
%% Output
patch size: 16
patches: torch.Size([4, 3, 16, 16])
embedding: torch.Size([4, 8])
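%% Cell type:markdown id: tags:

Since the convolution's kernel size equals its stride, unfolding into patches and convolving each one is equivalent to applying the convolution once to the whole image; this single-call form is the standard ViT-style patch embedding. A sketch of the shortcut, assuming the same 3-channel input and 8-dimensional output as above:

%% Cell type:code id: tags:

``` python
import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=16, stride=16)
x = torch.randn(1, 3, 32, 32)
# one call on the full image: (1, 8, 2, 2) -> 4 patch tokens of dim 8
tokens = conv(x).flatten(2).transpose(1, 2).reshape(-1, 8)
print(tokens.shape)  # torch.Size([4, 8])
```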
%% Cell type:code id: tags:
``` python
import math
import torch

def sinusoidal_encoding_table(n_position, d_hid, padding_idx=None):
    '''Generate sinusoidal position encoding table'''
    encoding_table = torch.zeros(n_position, d_hid)
    position = torch.arange(0, n_position).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_hid, 2) * -(math.log(10000.0) / d_hid))
    # even columns hold the sine terms, odd columns the cosine terms
    encoding_table[:, 0::2] = torch.sin(position * div_term)
    encoding_table[:, 1::2] = torch.cos(position * div_term)
    if padding_idx is not None:
        encoding_table[padding_idx] = 0.
    return encoding_table

seq_len = 100
embedding_dim = 64
pos_encoding = sinusoidal_encoding_table(seq_len, embedding_dim)
print(pos_encoding)

# plot the position encoding
fig, ax = plt.subplots(1, 1)
ax.imshow(pos_encoding.detach().numpy())
ax.axis('off')
```
%% Output
tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
          0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,
          1.3335e-04,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,
          2.6670e-04,  1.0000e+00],
        ...,
        [ 3.7961e-01, -9.2515e-01, -4.6453e-01,  ...,  9.9985e-01,
          1.2935e-02,  9.9992e-01],
        [-5.7338e-01, -8.1929e-01, -9.4349e-01,  ...,  9.9985e-01,
          1.3068e-02,  9.9991e-01],
        [-9.9921e-01,  3.9821e-02, -9.1628e-01,  ...,  9.9985e-01,
          1.3201e-02,  9.9991e-01]])
(-0.5, 63.5, 99.5, -0.5)
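%% Cell type:markdown id: tags:

For reference, the table above is the standard sinusoidal encoding from "Attention Is All You Need":

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

where $pos$ indexes the sequence position and $d$ is the embedding dimension (`d_hid` above).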
%% Cell type:code id: tags:
``` python
def get_pos_encoding(max_len, d_emb):
    pos = torch.arange(0, max_len).float().unsqueeze(1)
    i = torch.arange(0, d_emb, 2).float()
    div = torch.exp(-i * math.log(10000) / d_emb)
    sin = torch.sin(pos * div)
    cos = torch.cos(pos * div)
    # concatenate the sine and cosine halves: (1, max_len, d_emb)
    pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)
    return pos_encoding

seq_len = 100
embedding_dim = 64
pos_encoding = get_pos_encoding(seq_len, embedding_dim)
print(pos_encoding)

# plot the position encoding
fig, ax = plt.subplots(1, 1)
ax.imshow(pos_encoding.squeeze().detach().numpy())
ax.axis('off')
```
%% Output
tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.8415,  0.6816,  0.5332,  ...,  1.0000,  1.0000,  1.0000],
         [ 0.9093,  0.9975,  0.9021,  ...,  1.0000,  1.0000,  1.0000],
         ...,
         [ 0.3796, -0.4645, -0.9086,  ...,  0.9997,  0.9999,  0.9999],
         [-0.5734, -0.9435, -0.9914,  ...,  0.9997,  0.9998,  0.9999],
         [-0.9992, -0.9163, -0.7687,  ...,  0.9997,  0.9998,  0.9999]]])
(-0.5, 63.5, 99.5, -0.5)
%% Cell type:code id: tags:
``` python
# create position embedding
class PositionEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, out_channels):
        super(PositionEmbedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)

    def get_patches(self, x, patch_size=8):
        # get the patches
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        return patches

    # define cosinusoidal position embedding
    def get_cosine_position_embedding(self, x, patch_size=8):
        # get the patches
        patches = self.get_patches(x, patch_size=patch_size)
        # print shape of patches
        print('patches shape inside : ', patches.shape)
        # get the batch size
        batch_size = patches.shape[0]
        # get the number of patches
        no_patches = (32 / patch_size) ** 2
        # get the position embedding
        pos_embedding = torch.arange(0, no_patches).unsqueeze(1) / (10000 ** (torch.arange(0, self.out_channels, 2) / self.out_channels))
        # print self.out_channels
        print('self.out_channels inside : ', self.out_channels)
        # print shape
        print('pos_embedding inside : ', pos_embedding.shape)
        # get the sine and cosine embedding
        pos_embedding = torch.cat([torch.sin(pos_embedding), torch.cos(pos_embedding)], dim=1).reshape(1, no_patches, self.out_channels).permute(0, 2, 1)
        # expand the position embedding
        pos_embedding = pos_embedding.expand(batch_size, -1, -1)
        print('pos_embedding inside before return : ', pos_embedding.shape)
        return pos_embedding

    def forward(self, x):
        # get the patches
        patches = self.get_patches(x, patch_size=self.patch_size)
        # flatten the patches (permute first so each row holds one whole patch)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.in_channels, self.patch_size, self.patch_size)
        # get the embedding
        embedding = self.conv(patches)
        # flatten the embedding
        embedding = embedding.reshape(-1, self.out_channels)
        # normalize the embedding
        embedding = self.norm(embedding)
        # add the position embedding
        pos_embedding = self.get_cosine_position_embedding(x, patch_size=self.patch_size)
        print('pos embedding: ', pos_embedding.shape)
        return embedding + pos_embedding

# use the position embedding
patch_size = 16
pos_embedding = PositionEmbedding(patch_size=patch_size, in_channels=3, out_channels=8)
pos_embedding(img.unsqueeze(0))

# show the embedding and patches
embedding, patches = pos_embedding(img.unsqueeze(0))
print('patches: ', patches.shape)
print('embedding: ', embedding.shape)

# # plot the patches
# no = int(32 / patch_size)
# fig, ax = plt.subplots(no, no)
# for i in range(no):
#     for j in range(no):
#         ax[i, j].imshow(patches[i * no + j, :].permute(1, 2, 0))
#         ax[i, j].axis('off')
# # plot the embeddings
# no = int(32 / patch_size)
# fig, ax = plt.subplots(no, no)
# for i in range(no):
#     for j in range(no):
#         ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))
#         ax[i, j].axis('off')

# plot results of get_cosine_position_embedding
pos_embedding = pos_embedding.get_cosine_position_embedding(img.unsqueeze(0), patch_size=patch_size)
print('pos_embedding: ', pos_embedding.shape)
# plot the embeddings
no = int(32 / patch_size)
fig, ax = plt.subplots(no, no)
for i in range(no):
    for j in range(no):
        ax[i, j].imshow(pos_embedding[0, :, i * no + j].detach().numpy().reshape(1, -1))
        ax[i, j].axis('off')
```
%% Output
patches shape inside :  torch.Size([1, 3, 2, 2, 16, 16])
pos_embedding inside :  torch.Size([4, 4])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 68
     65 patch_size = 16
     66 pos_embedding = PositionEmbedding(patch_size=patch_size, in_channels=3, out_channels=8)
---> 68 pos_embedding(img.unsqueeze(0))
     70 # show the embedding and patches
     71 embedding, patches = pos_embedding(img.unsqueeze(0))

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[6], line 58, in PositionEmbedding.forward(self, x)
     56 embedding = self.norm(embedding)
     57 # add the position embedding
---> 58 pos_embedding = self.get_cosine_position_embedding(x, patch_size=self.patch_size)
     60 print('pos embedding: ', pos_embedding.shape)
     61 return embedding + pos_embedding

Cell In[6], line 38, in PositionEmbedding.get_cosine_position_embedding(self, x, patch_size)
     35 print('pos_embedding inside : ', pos_embedding.shape)
     37 # get the sine and cosine embedding
---> 38 pos_embedding = torch.cat([torch.sin(pos_embedding), torch.cos(pos_embedding)], dim=1).reshape(1, no_patches, self.out_channels).permute(0, 2, 1)
     39 # expand the position embedding
     40 pos_embedding = pos_embedding.expand(batch_size, -1, -1)

TypeError: reshape(): argument 'shape' must be tuple of SymInts, but found element of type float at pos 2
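%% Cell type:markdown id: tags:

The `TypeError` above comes from `no_patches = (32 / patch_size) ** 2`, which is a Python `float`, while `reshape()` requires integer sizes. A minimal sketch of a fix, assuming the 32x32 CIFAR input used throughout (dropping the final `permute` also leaves the table broadcastable against the `(no_patches, out_channels)` patch embeddings it is added to):

%% Cell type:code id: tags:

``` python
import torch

patch_size, out_channels = 16, 8
no_patches = (32 // patch_size) ** 2  # integer division -> int, not float
pos = torch.arange(0, no_patches).unsqueeze(1) / (
    10000 ** (torch.arange(0, out_channels, 2) / out_channels))
# sine half then cosine half, matching get_cosine_position_embedding
pos_table = torch.cat([torch.sin(pos), torch.cos(pos)], dim=1)
print(pos_table.shape)  # torch.Size([4, 8])
```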