diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py index 45f47524c462b5d53f682db53bb1b9b4fab09def..f59818a97dde2841deabcc74b0c55618531c1bb2 100644 --- a/cifar_example/cifar_example_transformer.py +++ b/cifar_example/cifar_example_transformer.py @@ -14,6 +14,8 @@ import wandb # import resnet18 from trochvision from torchvision.models import resnet18 +from torch.nn import TransformerEncoderLayer + # import cifar10 dataset from torchvision.datasets import CIFAR10 @@ -44,17 +46,26 @@ def train(model, device, train_loader, optimizer, epoch): epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) - # log to wandb - wandb.log({"loss": loss.item()}) - wandb.log({"epoch": epoch}) - - # compute the accuracy - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct = pred.eq(target.view_as(pred)).sum().item() - accuracy = correct / len(data) + # log to wandb + wandb.log({"loss": loss.item()}) + wandb.log({"epoch": epoch}) + # get all the parameters of the model + params = list(model.named_parameters()) + # log the gradients + for name, param in params: + wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())}) + + # log the weights + for name, param in params: + wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())}) + + # compute the accuracy + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct = pred.eq(target.view_as(pred)).sum().item() + accuracy = correct / len(data) - # log to wandb - wandb.log({"accuracy": accuracy}) + # log to wandb + wandb.log({"accuracy": accuracy}) # test the model @@ -101,20 +112,23 @@ class MyModel(nn.Module): # define embedding class class Embedding(nn.Module): - def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False): + def __init__(self, patch_size, in_channels, out_channels, batch_size=1, return_patches=False, extra_token=False): super(Embedding, self).__init__() self.patch_size = patch_size self.in_channels = in_channels self.out_channels = out_channels self.return_patches = return_patches + self.class_embedding = nn.Parameter(torch.randn(1, out_channels)) self.classify = extra_token - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size) + self.patch_conv = nn.Conv2d( + in_channels, out_channels, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(out_channels) - + self.proj = nn.Linear(out_channels, out_channels) def get_patches(self, x, patch_size=8): # get the patches - patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device) + patches = x.unfold(2, patch_size, patch_size).unfold( + 3, patch_size, patch_size).to(x.device) return patches @@ -130,80 +144,117 @@ class Embedding(nn.Module): pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb) return pos_encoding - + def forward(self, x): - # get the patches - patches = self.get_patches(x, patch_size=self.patch_size) - # flatten the patches - patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size) - # get the embedding - embedding = self.conv(patches) - # flatten the embedding - embedding = embedding.reshape(-1, self.out_channels) - # normalize the embedding - embedding = self.norm(embedding) - # add the positional encoding - pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0]) - embedding = embedding + 
pos_encoding.to(x.device) - # reshape the embedding + embedding = self.patch_conv(x) + + # flatten the embedding embedding = embedding.reshape(x.shape[0], -1, self.out_channels) if self.classify: - # add the classification token - classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device) - embedding = torch.cat((classification_token, embedding), dim=1) - - if self.return_patches: - return embedding, patches - else: - return embedding + class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1) + embedding = torch.cat([class_embedding, embedding], dim=1) + # normalize the embedding + embedding = self.norm(embedding) -# define transformer class -class SelfAttention(nn.Module): - def __init__(self, embed_dim): - super().__init__() + # project the embedding + embedding = self.proj(embedding) - # Query, Key, Value weight matrices - self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3) + # add the positional encoding account for batch size + pos_encoding = self.get_pos_encoding( + self.out_channels, embedding.shape[1]).to(x.device) - # Final output weight matrix - self.output_linear = nn.Linear(embed_dim, embed_dim) - - def forward(self, x): - batch_size, seq_len, embed_dim = x.size() + embedding = embedding + pos_encoding - # Create queries, keys, and values - qkv = self.qkv_linear(x) - q, k, v = torch.split(qkv, embed_dim, dim=-1) - # Compute attention scores - scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5) - attn = torch.softmax(scores, dim=-1) + if self.return_patches: + patches = self.get_patches(x, self.patch_size) + patches = patches.reshape( + x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size) - # Apply attention to values - weighted_values = torch.matmul(attn, v) + return embedding, patches + else: + return embedding - # Apply final output weight matrix - output = self.output_linear(weighted_values) +class SelfAttentionParam(nn.Module): + def __init__(self, in_features, out_features) -> None: + super(SelfAttentionParam, self).__init__() + self.query = nn.Linear(in_features, out_features) + self.key = nn.Linear(in_features, out_features) + self.value = nn.Linear(in_features, out_features) - return output + def forward(self, x): + batch_size, num_embeddings, embedding_dim = x.size() + Q = self.query(x).view(batch_size, num_embeddings, -1) + K = self.key(x).view(batch_size, num_embeddings, -1) + V = self.value(x).view(batch_size, num_embeddings, -1) + # Q, K, V = [batch_size, num_embeddings, embedding_dim] + energy = torch.bmm(Q, K.permute(0, 2, 1)) + # energy = [batch_size, num_embeddings, num_embeddings] + attention = torch.softmax(energy, dim=-1) + out = torch.bmm(attention, V) + out = out.view(batch_size, num_embeddings, -1) + return out + + +class MultiHeadAttention(nn.Module): + def __init__(self, embedd_size, heads=8) -> None: + super(MultiHeadAttention, self).__init__() + self.heads = heads + self.attention = nn.ModuleList([SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)]) + self.projection = nn.Linear(heads * embedd_size, embedd_size) + def forward(self, x): + out = [self.attention[i](x) for i in range(self.heads)] + out = torch.cat(out, dim=2) -from torch.nn import TransformerEncoderLayer + out = self.projection(out) + return out + + +class MyTransformerLayer(nn.Module): + def __init__(self, d_model, nhead, dropout=0.1, batch_first=False): + super(MyTransformerLayer, self).__init__() + # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) + 
print("d_model", d_model) + print("nhead", nhead) + + self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead) + self.linear1 = nn.Linear(d_model, d_model) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_model, d_model) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + # src2 = self.self_attn(src, src, src, attn_mask=src_mask, + # key_padding_mask=src_key_padding_mask)[0] + src2 = self.self_attn(src) + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src class PthBasedTransformer(nn.Module): def __init__(self, embedding_size=64) -> None: super().__init__() self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True) - self.self_attention = TransformerEncoderLayer( - d_model=embedding_size, - nhead=16, - dim_feedforward=embedding_size*4, - dropout=0.3 - ) + # self.self_attention = TransformerEncoderLayer( + # d_model=embedding_size, + # nhead=16, + # dim_feedforward=embedding_size*4, + # dropout=0.3, + # batch_first=True + # ) + + self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True) self.fc = nn.Linear(embedding_size, 10) def forward(self, x): @@ -219,23 +270,6 @@ class PthBasedTransformer(nn.Module): return context - -class MyTransformer(nn.Module): - def __init__(self): - super(MyTransformer, self).__init__() - self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=64) - self.attention = SelfAttention(embed_dim=64) - self.fc1 = nn.Linear(64, 10) - - - def forward(self, x): - x = self.embedding(x) - x = self.attention(x) - x = self.fc1(x) - return x - - - # main function def main(train_model=False, model_type="resnet"): # use cuda if available @@ -250,9 +284,6 @@ def main(train_model=False, model_type="resnet"): elif model_type == "cnn": # get the cnn model model = MyModel().to(device) - elif model_type == "transformer": - # get the transformer model - model = MyTransformer().to(device) elif model_type == "pth_transformer": # get the transformer model model = PthBasedTransformer().to(device) @@ -283,6 +314,7 @@ if __name__ == '__main__': # model_type = "resnet" model_type = "pth_transformer" + # Create a config object for wandb config = { 'model_type': model_type, diff --git a/cifar_example/self_attention.ipynb b/cifar_example/self_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a223d112a9797344c7c5b2934a9e370c50f0db99 --- /dev/null +++ b/cifar_example/self_attention.ipynb @@ -0,0 +1,495 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Simple self-attention implementation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "import torch\n", + "\n", + "# create a transformer class model\n", + "from torch import nn\n", + "import torch\n", + "from torch.nn import functional as F\n", + "import math\n", + "\n", + "from torchvision.datasets import CIFAR10\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", + "\n", + "# create the same class but parameterized\n", + "class SelfAttentionParam(nn.Module):\n", + " def 
__init__(self, in_channels, out_channels, kernel_size=1) -> None:\n", + " super(SelfAttentionParam, self).__init__()\n", + " self.query = nn.Conv2d(\n", + " in_channels=in_channels, \n", + " out_channels=out_channels, \n", + " kernel_size=kernel_size\n", + " )\n", + " self.key = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)\n", + " self.value = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)\n", + "\n", + " def forward(self, x):\n", + " batch_size, C, H, W = x.size()\n", + " Q = self.query(x).view(batch_size, -1, H*W)\n", + " K = self.key(x).view(batch_size, -1, H*W)\n", + " V = self.value(x).view(batch_size, -1, H*W)\n", + " # Q, K, V = [batch_size, N, N]\n", + " energy = torch.bmm(Q.permute(0, 2, 1), K)\n", + " # energy = [batch_size, N, N]\n", + " attention = torch.softmax(energy, dim=-1)\n", + " out = torch.bmm(V, attention.permute(0, 2, 1))\n", + " out = out.view(batch_size, C, H, W)\n", + " return out\n", + "\n", + "# create multiple heads\n", + "class MultiHeadAttention(nn.Module):\n", + " def __init__(self, in_channels, out_channels, kernel_size=1, heads=8) -> None:\n", + " super(MultiHeadAttention, self).__init__()\n", + " self.heads = heads\n", + " self.attention = nn.ModuleList([SelfAttentionParam(in_channels, out_channels, kernel_size) for _ in range(heads)])\n", + "\n", + " def forward(self, x):\n", + " batch_size, C, H, W = x.size()\n", + " out = torch.zeros(batch_size, C, H, W).to(x.device)\n", + " for i in range(self.heads):\n", + " out += self.attention[i](x)\n", + " return out\n", + " \n", + "\n", + "# create the same class but assuming the input is already flattened (N, num_features, embedding_size)\n", + "class SelfAttentionParamFlattened(nn.Module):\n", + " def __init__(self, num_features, embedding_size) -> None:\n", + " super(SelfAttentionParamFlattened, self).__init__()\n", + " self.query = nn.Linear(num_features, embedding_size)\n", + " self.key = nn.Linear(num_features, embedding_size)\n", + " self.value = nn.Linear(num_features, embedding_size)\n", + "\n", + " def forward(self, x):\n", + " batch_size, N, E = x.size()\n", + " Q = self.query(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n", + " K = self.key(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n", + " V = self.value(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n", + "\n", + " # Q, K, V = [batch_size, N, E]\n", + " energy = torch.bmm(Q, K.transpose(1,2)) / math.sqrt(E)\n", + " # energy = [batch_size, N, N]\n", + " attention = torch.softmax(energy, dim=-1)\n", + " out = torch.bmm(attention, V)\n", + " return out\n", + " \n", + "# create multiple heads\n", + "class MultiHeadAttentionFlattened(nn.Module):\n", + " def __init__(self, num_features, embedding_size, heads=8) -> None:\n", + " super(MultiHeadAttentionFlattened, self).__init__()\n", + " self.heads = heads\n", + " self.attention = nn.ModuleList([SelfAttentionParamFlattened(num_features, embedding_size) for _ in range(heads)])\n", + "\n", + " def forward(self, x):\n", + " batch_size, N, E = x.size()\n", + " out = torch.zeros(batch_size, N, E).to(x.device)\n", + " for i in range(self.heads):\n", + " out += self.attention[i](x)\n", + " return out\n", + "\n", + "\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + 
" self.return_patches = return_patches\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n", + " self.classify = extra_token\n", + " self.patch_conv = nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", + " self.norm = nn.LayerNorm(out_channels)\n", + " self.proj = nn.Linear(out_channels, out_channels)\n", + "\n", + " def get_patches(self, x, patch_size=8):\n", + " # get the patches\n", + " patches = x.unfold(2, patch_size, patch_size).unfold(\n", + " 3, patch_size, patch_size).to(x.device)\n", + "\n", + " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", + "\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", + "\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + "\n", + " return pos_encoding\n", + "\n", + " def forward(self, x):\n", + "\n", + " embedding = self.patch_conv(x)\n", + "\n", + " # flatten the embedding\n", + " embedding = embedding.reshape(x.shape[0], -1, self.out_channels)\n", + "\n", + " if self.classify:\n", + " class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1)\n", + " embedding = torch.cat([class_embedding, embedding], dim=1)\n", + "\n", + " # normalize the embedding\n", + " embedding = self.norm(embedding)\n", + "\n", + " # project the embedding\n", + " embedding = self.proj(embedding)\n", + "\n", + " # add the positional encoding account for batch size\n", + " pos_encoding = self.get_pos_encoding(\n", + " self.out_channels, embedding.shape[1]).to(x.device)\n", + "\n", + " embedding = embedding + pos_encoding\n", + "\n", + " # apply the dropout\n", + " embedding = F.dropout(embedding, p=0.1, training=self.training)\n", + "\n", + "\n", + " if self.return_patches:\n", + " patches = self.get_patches(x, self.patch_size)\n", + " patches = patches.reshape(\n", + " x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n", + "\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n", + " \n", + "\n", + "\n", + "# create the transformer class\n", + "class MyBasicTransformer(nn.Module):\n", + " def __init__(self) -> None:\n", + " super(MyBasicTransformer, self).__init__()\n", + " self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=512, return_patches=True)\n", + " # use MultiHeadAttentionFlattened class\n", + " self.attention = MultiHeadAttentionFlattened(num_features=512, embedding_size=512, heads=8)\n", + " self.norm = nn.LayerNorm(512)\n", + " self.mlp = nn.Sequential(\n", + " nn.Linear(32768, 2048), # input size should be 4096 instead of 512\n", + " nn.GELU(),\n", + " nn.Linear(2048, 512)\n", + " )\n", + " self.norm2 = nn.LayerNorm(512)\n", + " self.classify = nn.Linear(512, 10)\n", + "\n", + " def flatten(self, x):\n", + " batch_size, N, E = x.size()\n", + " return x.view(batch_size, -1)\n", + "\n", + " def forward(self, x):\n", + " embedding, patches = self.embedding(x)\n", + " out = self.attention(embedding)\n", + " out = self.norm(out)\n", + " out = self.flatten(out)\n", + " out = self.mlp(out)\n", + " out = self.norm2(out)\n", + " out = self.classify(out)\n", + " return out, patches\n", + " \n", + "\n", + "# create the dataset\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", + "])\n", + "\n", + "trainset = 
CIFAR10(root='./data', train=True, download=True, transform=transform)\n", + "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)\n", + "\n", + "testset = CIFAR10(root='./data', train=False, download=True, transform=transform)\n", + "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)\n", + "\n", + "# create the model\n", + "model = MyBasicTransformer()\n", + "model = model.to('cuda')\n", + "\n", + "# create the optimizer\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", + "\n", + "# create the loss function\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "# train the model\n", + "for epoch in range(10):\n", + " running_loss = 0.0\n", + " for i, data in enumerate(trainloader, 0):\n", + " inputs, labels = data\n", + " inputs = inputs.to('cuda')\n", + " labels = labels.to('cuda')\n", + " optimizer.zero_grad()\n", + " outputs, _ = model(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " running_loss += loss.item()\n", + " if i % 100 == 99:\n", + " print('[%d, %5d] loss: %.3f' %\n", + " (epoch + 1, i + 1, running_loss / 100))\n", + " running_loss = 0.0\n", + "\n", + "print('Finished Training')\n", + "\n", + "# test the model\n", + "correct = 0\n", + "total = 0\n", + "with torch.no_grad():\n", + " for data in testloader:\n", + " images, labels = data\n", + " images = images.to('cuda')\n", + " labels = labels.to('cuda')\n", + " outputs, _ = model(images)\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print('Accuracy of the network on the 10000 test images: %d %%' % (\n", + " 100 * correct / total))\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 8, 32, 32])\n", + "torch.Size([1, 128, 8])\n" + ] + } + ], + "source": [ + "from torch import nn\n", + "import torch\n", + "\n", + "# create a transformer class model\n", + "from torch import nn\n", + "import torch\n", + "from torch.nn import functional as F\n", + "import math\n", + "\n", + "from torchvision.datasets import CIFAR10\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", + "import torchvision\n", + "\n", + "# create the same class but parameterized\n", + "class SelfAttentionParam(nn.Module):\n", + " def __init__(self, in_channels, out_channels, kernel_size=1) -> None:\n", + " super(SelfAttentionParam, self).__init__()\n", + " self.query = nn.Conv2d(\n", + " in_channels=in_channels, \n", + " out_channels=in_channels, \n", + " kernel_size=kernel_size\n", + " )\n", + " self.key = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)\n", + " self.value = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)\n", + "\n", + " def forward(self, x):\n", + " batch_size, C, H, W = x.size()\n", + " Q = self.query(x).view(batch_size, -1, H*W)\n", + " K = self.key(x).view(batch_size, -1, H*W)\n", + " V = self.value(x).view(batch_size, -1, H*W)\n", + " # Q, K, V = [batch_size, N, N]\n", + " energy = torch.bmm(Q.permute(0, 2, 1), K)\n", + " # energy = [batch_size, N, N]\n", + " attention = torch.softmax(energy, dim=-1)\n", + " out = torch.bmm(V, attention.permute(0, 2, 1))\n", + " out = out.view(batch_size, C, H, W)\n", + " return out\n", + "\n", + "# 
create multiple heads\n", + "class MultiHeadAttention(nn.Module):\n", + " def __init__(self, in_channels, out_channels, kernel_size=1, heads=8) -> None:\n", + " super(MultiHeadAttention, self).__init__()\n", + " self.heads = heads\n", + " self.attention = nn.ModuleList([SelfAttentionParam(in_channels, out_channels, kernel_size) for _ in range(heads)])\n", + "\n", + " def forward(self, x):\n", + " batch_size, C, H, W = x.size()\n", + " out = torch.zeros(batch_size, C, H, W).to(x.device)\n", + " for i in range(self.heads):\n", + " out += self.attention[i](x)\n", + " return out\n", + "\n", + "\n", + "\n", + "\n", + "# get sample data for MultiHeadAttention\n", + "x = torch.randn(1, 8, 32, 32)\n", + "# create the model\n", + "model = MultiHeadAttention(in_channels=x.shape[1], out_channels=64, kernel_size=1, heads=8)\n", + "# forward pass\n", + "out = model(x)\n", + "print(out.shape)\n", + "\n", + "from torch import nn\n", + "model = nn.MultiheadAttention(embed_dim=8, num_heads=8)\n", + "\n", + "# create sample data that matches the input shape\n", + "x = torch.randn(1, 128, 8)\n", + "\n", + "# forward pass\n", + "out, _ = model(x, x, x)\n", + "\n", + "print(out.shape)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create the embedding class with class token and positional encoding\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False) -> None:\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.return_patches = return_patches\n", + " self.patch_embedding = nn.Conv2d(\n", + " in_channels=in_channels, \n", + " out_channels=out_channels, \n", + " kernel_size=patch_size, \n", + " stride=patch_size\n", + " )\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels, 1, 1))\n", + " self.positional_embedding = nn.Parameter(torch.randn(1, out_channels, 32, 32))\n", + " self.norm = nn.LayerNorm(out_channels)\n", + "\n", + " def forward(self, x):\n", + " batch_size, C, H, W = x.size()\n", + " patches = self.patch_embedding(x)\n", + " patches = patches.flatten(2).transpose(1, 2)\n", + " embedding = self.class_embedding.expand(batch_size, -1, -1, -1).flatten(2).transpose(1, 2) + patches\n", + "\n", + "\n", + "\n", + " embedding = embedding + self.positional_embedding\n", + " embedding = self.norm(embedding)\n", + " if self.return_patches:\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([32, 5, 16])\n" + ] + } + ], + "source": [ + "class SelfAttentionParam(nn.Module):\n", + " def __init__(self, in_features, out_features) -> None:\n", + " super(SelfAttentionParam, self).__init__()\n", + " self.query = nn.Linear(in_features, out_features)\n", + " self.key = nn.Linear(in_features, out_features)\n", + " self.value = nn.Linear(in_features, out_features)\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_embeddings, embedding_dim = x.size()\n", + " Q = self.query(x).view(batch_size, num_embeddings, -1)\n", + " K = self.key(x).view(batch_size, num_embeddings, -1)\n", + " V = self.value(x).view(batch_size, num_embeddings, -1)\n", + " # Q, K, V = [batch_size, num_embeddings, embedding_dim]\n", + " energy = torch.bmm(Q, K.permute(0, 2, 1))\n", + " # energy = [batch_size, 
num_embeddings, num_embeddings]\n", + " attention = torch.softmax(energy, dim=-1)\n", + " out = torch.bmm(attention, V)\n", + " out = out.view(batch_size, num_embeddings, -1)\n", + " return out\n", + "\n", + "\n", + "class MultiHeadAttention(nn.Module):\n", + " def __init__(self, embedd_size, heads=8) -> None:\n", + " super(MultiHeadAttention, self).__init__()\n", + " self.heads = heads\n", + " self.attention = nn.ModuleList([SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)])\n", + " self.projection = nn.Linear(heads * embedd_size, embedd_size)\n", + "\n", + " def forward(self, x):\n", + " out = [self.attention[i](x) for i in range(self.heads)]\n", + " out = torch.cat(out, dim=2)\n", + "\n", + " out = self.projection(out)\n", + " return out\n", + " \n", + "\n", + "# create a sample data\n", + "x = torch.randn(32, 10, 16)\n", + "\n", + "# create the model\n", + "model = MultiHeadAttention(embedd_size=16, heads=4)\n", + "\n", + "# forward pass\n", + "out = model(x)\n", + "\n", + "print(out.shape)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
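
For reference, below is a minimal standalone sketch of the pipeline this patch assembles in cifar_example_transformer.py: a Conv2d patch embedding with a learnable class token and sinusoidal positional encoding, feeding a post-norm residual self-attention block and a linear classification head. It is an approximation under stated assumptions, not the patch's code: it swaps the hand-rolled SelfAttentionParam/MultiHeadAttention heads for torch.nn.MultiheadAttention, omits the extra Linear projection and the wandb logging, and the names PatchEmbeddingSketch and TinyEncoderBlock (and heads=8) are illustrative stand-ins.

# Minimal, runnable approximation of the patch's pipeline. Class names and
# heads=8 are illustrative assumptions, not the names used in the patch.
import math

import torch
from torch import nn


class PatchEmbeddingSketch(nn.Module):
    """Patch embedding with a class token and fixed sinusoidal positions."""

    def __init__(self, patch_size=16, in_channels=3, embed_dim=64):
        super().__init__()
        self.embed_dim = embed_dim
        # a conv with stride == kernel_size maps each patch to one embedding vector
        self.patch_conv = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.norm = nn.LayerNorm(embed_dim)

    @staticmethod
    def sinusoidal_positions(seq_len, embed_dim):
        # same construction as Embedding.get_pos_encoding in the patch:
        # sin terms fill the first half of each vector, cos terms the second half
        pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        i = torch.arange(0, embed_dim, 2, dtype=torch.float32)
        div = torch.exp(-i * math.log(10000.0) / embed_dim)
        sin, cos = torch.sin(pos * div), torch.cos(pos * div)
        return torch.cat((sin, cos), dim=1).view(1, seq_len, embed_dim)

    def forward(self, x):
        batch = x.shape[0]
        tokens = self.patch_conv(x).flatten(2).transpose(1, 2)            # (B, N, E)
        tokens = torch.cat([self.class_token.expand(batch, -1, -1), tokens], dim=1)
        tokens = self.norm(tokens)
        pos = self.sinusoidal_positions(tokens.shape[1], self.embed_dim).to(x.device)
        return tokens + pos


class TinyEncoderBlock(nn.Module):
    """Post-norm residual block mirroring the structure of MyTransformerLayer,
    but using torch.nn.MultiheadAttention instead of the hand-rolled heads."""

    def __init__(self, embed_dim=64, heads=8, dropout=0.3):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, heads,
                                          dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.ReLU(),
                                nn.Dropout(dropout), nn.Linear(4 * embed_dim, embed_dim))
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))


if __name__ == "__main__":
    embed = PatchEmbeddingSketch(patch_size=16, in_channels=3, embed_dim=64)
    block = TinyEncoderBlock(embed_dim=64, heads=8, dropout=0.3)
    head = nn.Linear(64, 10)

    x = torch.randn(2, 3, 32, 32)        # a CIFAR-10 sized batch
    tokens = embed(x)                    # (2, 5, 64): 4 patches of 16x16 plus the class token
    logits = head(block(tokens)[:, 0])   # classify from the class token (the extra_token in Embedding)
    print(tokens.shape, logits.shape)    # torch.Size([2, 5, 64]) torch.Size([2, 10])

The shape check mirrors the patch's setup: patch_size=16 on 32x32 CIFAR images yields 4 patch tokens plus one class token. The patch's own PthBasedTransformer uses embedding_size=64 with nhead=16 and per-head linear attention concatenated and re-projected, but the residual-plus-LayerNorm wiring is the same.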