diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py
index 45f47524c462b5d53f682db53bb1b9b4fab09def..f59818a97dde2841deabcc74b0c55618531c1bb2 100644
--- a/cifar_example/cifar_example_transformer.py
+++ b/cifar_example/cifar_example_transformer.py
@@ -14,6 +14,8 @@ import wandb
 # import resnet18 from trochvision
 from torchvision.models import resnet18
 
+from torch.nn import TransformerEncoderLayer
+
 # import cifar10 dataset
 from torchvision.datasets import CIFAR10
 
@@ -44,17 +46,26 @@ def train(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
 
-        # log to wandb
-        wandb.log({"loss": loss.item()})
-        wandb.log({"epoch": epoch})
-
-        # compute the accuracy
-        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
-        correct = pred.eq(target.view_as(pred)).sum().item()
-        accuracy = correct / len(data)
+            # log to wandb
+            wandb.log({"loss": loss.item()})
+            wandb.log({"epoch": epoch})
+            # get all the parameters of the model
+            params = list(model.named_parameters())
+            # log the gradients (skip parameters that have not received a gradient)
+            for name, param in params:
+                if param.grad is not None:
+                    wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})
+
+            # log the weights
+            for name, param in params:
+                wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
+            
+            # compute the accuracy
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            correct = pred.eq(target.view_as(pred)).sum().item()
+            accuracy = correct / len(data)
 
-        # log to wandb
-        wandb.log({"accuracy": accuracy})
+            # log to wandb
+            wandb.log({"accuracy": accuracy})
 
 
 # test the model
@@ -101,20 +112,23 @@ class MyModel(nn.Module):
 
 # define embedding class
 class Embedding(nn.Module):
-    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
+    def __init__(self, patch_size, in_channels, out_channels, batch_size=1, return_patches=False, extra_token=False):
         super(Embedding, self).__init__()
         self.patch_size = patch_size
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.return_patches = return_patches
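+        # learnable classification token, prepended to the patch sequence when extra_token=True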
+        self.class_embedding = nn.Parameter(torch.randn(1, out_channels))
         self.classify = extra_token
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
+        self.patch_conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
         self.norm = nn.LayerNorm(out_channels)
-
+        self.proj = nn.Linear(out_channels, out_channels)
 
     def get_patches(self, x, patch_size=8):
         # get the patches
-        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
+        patches = x.unfold(2, patch_size, patch_size).unfold(
+            3, patch_size, patch_size).to(x.device)
 
         return patches
 
@@ -130,80 +144,117 @@ class Embedding(nn.Module):
         pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)
 
         return pos_encoding
-        
+
     def forward(self, x):
-        # get the patches
-        patches = self.get_patches(x, patch_size=self.patch_size)
-        # flatten the patches
-        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
-        # get the embedding
-        embedding = self.conv(patches)
-        # flatten the embedding
-        embedding = embedding.reshape(-1, self.out_channels)
-        # normalize the embedding
-        embedding = self.norm(embedding)
-        # add the positional encoding
-        pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-        embedding = embedding + pos_encoding.to(x.device)
 
-        # reshape the embedding
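+        # each non-overlapping patch_size x patch_size patch becomes one out_channels-dim token (kernel = stride = patch_size)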
+        embedding = self.patch_conv(x)
+
+        # flatten the embedding
         embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
 
         if self.classify:
-            # add the classification token
-            classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
-            embedding = torch.cat((classification_token, embedding), dim=1)
-        
-        if self.return_patches:
-            return embedding, patches
-        else:
-            return embedding
+            class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1)
+            embedding = torch.cat([class_embedding, embedding], dim=1)
 
+        # normalize the embedding
+        embedding = self.norm(embedding)
 
-# define transformer class
-class SelfAttention(nn.Module):
-    def __init__(self, embed_dim):
-        super().__init__()
+        # project the embedding
+        embedding = self.proj(embedding)
 
-        # Query, Key, Value weight matrices
-        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)
+        # add the sinusoidal positional encoding; it broadcasts over the batch dimension
+        pos_encoding = self.get_pos_encoding(
+            self.out_channels, embedding.shape[1]).to(x.device)
 
-        # Final output weight matrix
-        self.output_linear = nn.Linear(embed_dim, embed_dim)
-    
-    def forward(self, x):
-        batch_size, seq_len, embed_dim = x.size()
+        embedding = embedding + pos_encoding
 
-        # Create queries, keys, and values
-        qkv = self.qkv_linear(x)
-        q, k, v = torch.split(qkv, embed_dim, dim=-1)
 
-        # Compute attention scores
-        scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5)
-        attn = torch.softmax(scores, dim=-1)
+        if self.return_patches:
+            patches = self.get_patches(x, self.patch_size)
+            # reorder to (batch, patches, channels, patch, patch) before flattening the patch grid
+            patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
+                x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)
 
-        # Apply attention to values
-        weighted_values = torch.matmul(attn, v)
+            return embedding, patches
+        else:
+            return embedding
 
-        # Apply final output weight matrix
-        output = self.output_linear(weighted_values)
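+
+# single-head self-attention over (batch, tokens, embedding_dim) tensors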
+class SelfAttentionParam(nn.Module):
+    def __init__(self, in_features, out_features) -> None:
+        super(SelfAttentionParam, self).__init__()
+        self.query = nn.Linear(in_features, out_features)
+        self.key = nn.Linear(in_features, out_features)
+        self.value = nn.Linear(in_features, out_features)
 
-        return output
+    def forward(self, x):
+        batch_size, num_embeddings, embedding_dim = x.size()
+        Q = self.query(x).view(batch_size, num_embeddings, -1)
+        K = self.key(x).view(batch_size, num_embeddings, -1)
+        V = self.value(x).view(batch_size, num_embeddings, -1)
+        # Q, K, V = [batch_size, num_embeddings, embedding_dim]
+        energy = torch.bmm(Q, K.permute(0, 2, 1)) / (embedding_dim ** 0.5)  # scaled dot-product
+        # energy = [batch_size, num_embeddings, num_embeddings]
+        attention = torch.softmax(energy, dim=-1)
+        out = torch.bmm(attention, V)
+        out = out.view(batch_size, num_embeddings, -1)
+        return out
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, embedd_size, heads=8) -> None:
+        super(MultiHeadAttention, self).__init__()
+        self.heads = heads
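+        # each head attends over the full embedd_size; head outputs are concatenated and projected back to embedd_size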
+        self.attention = nn.ModuleList([SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)])
+        self.projection = nn.Linear(heads * embedd_size, embedd_size)
 
+    def forward(self, x):
+        out = [self.attention[i](x) for i in range(self.heads)]
+        out = torch.cat(out, dim=2)
 
-from torch.nn import TransformerEncoderLayer
+        out = self.projection(out)
+        return out
+    
+
+class MyTransformerLayer(nn.Module):
+    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
+        super(MyTransformerLayer, self).__init__()
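+        # batch_first is accepted for signature parity with nn.TransformerEncoderLayer; the attention below always expects (batch, seq, embed)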
+        # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+        print("d_model", d_model)
+        print("nhead", nhead)
+
+        self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
+        self.linear1 = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(d_model, d_model)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+    def forward(self, src, src_mask=None, src_key_padding_mask=None):
+        # src2 = self.self_attn(src, src, src, attn_mask=src_mask,
+        #                       key_padding_mask=src_key_padding_mask)[0]
+        src2 = self.self_attn(src)
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
 
 
 class PthBasedTransformer(nn.Module):
     def __init__(self, embedding_size=64) -> None:
         super().__init__()
         self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
-        self.self_attention = TransformerEncoderLayer(
-            d_model=embedding_size, 
-            nhead=16, 
-            dim_feedforward=embedding_size*4, 
-            dropout=0.3
-            )
+        # self.self_attention = TransformerEncoderLayer(
+        #     d_model=embedding_size, 
+        #     nhead=16, 
+        #     dim_feedforward=embedding_size*4, 
+        #     dropout=0.3,
+        #     batch_first=True
+        #     )
+        
+        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True)
         self.fc = nn.Linear(embedding_size, 10)
         
     def forward(self, x):
@@ -219,23 +270,6 @@ class PthBasedTransformer(nn.Module):
     
         return context
 
-
-class MyTransformer(nn.Module):
-    def __init__(self):
-        super(MyTransformer, self).__init__()
-        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=64)
-        self.attention = SelfAttention(embed_dim=64)
-        self.fc1 = nn.Linear(64, 10)
-
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.attention(x)
-        x = self.fc1(x)
-        return x
-
-
-
 # main function
 def main(train_model=False, model_type="resnet"):
     # use cuda if available
@@ -250,9 +284,6 @@ def main(train_model=False, model_type="resnet"):
     elif model_type == "cnn":
         # get the cnn model
         model = MyModel().to(device)
-    elif model_type == "transformer":
-        # get the transformer model
-        model = MyTransformer().to(device)
     elif model_type == "pth_transformer":
         # get the transformer model
         model = PthBasedTransformer().to(device)
@@ -283,6 +314,7 @@ if __name__ == '__main__':
     # model_type = "resnet"
     model_type = "pth_transformer"
 
+
     # Create a config object for wandb
     config = {
         'model_type': model_type,
diff --git a/cifar_example/self_attention.ipynb b/cifar_example/self_attention.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a223d112a9797344c7c5b2934a9e370c50f0db99
--- /dev/null
+++ b/cifar_example/self_attention.ipynb
@@ -0,0 +1,495 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Simple self-attention implementation\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch import nn\n",
+    "import torch\n",
+    "\n",
+    "# create a transformer class model\n",
+    "from torch import nn\n",
+    "import torch\n",
+    "from torch.nn import functional as F\n",
+    "import math\n",
+    "\n",
+    "from torchvision.datasets import CIFAR10\n",
+    "# import torchvision transforms\n",
+    "from torchvision import transforms\n",
+    "\n",
+    "# create the same class but parameterized\n",
+    "class SelfAttentionParam(nn.Module):\n",
+    "    def __init__(self, in_channels, out_channels, kernel_size=1) -> None:\n",
+    "        super(SelfAttentionParam, self).__init__()\n",
+    "        self.query = nn.Conv2d(\n",
+    "            in_channels=in_channels, \n",
+    "            out_channels=out_channels, \n",
+    "            kernel_size=kernel_size\n",
+    "            )\n",
+    "        self.key = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)\n",
+    "        self.value = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, C, H, W = x.size()\n",
+    "        Q = self.query(x).view(batch_size, -1, H*W)\n",
+    "        K = self.key(x).view(batch_size, -1, H*W)\n",
+    "        V = self.value(x).view(batch_size, -1, H*W)\n",
+    "        # Q, K, V = [batch_size, N, N]\n",
+    "        energy = torch.bmm(Q.permute(0, 2, 1), K)\n",
+    "        # energy = [batch_size, N, N]\n",
+    "        attention = torch.softmax(energy, dim=-1)\n",
+    "        out = torch.bmm(V, attention.permute(0, 2, 1))\n",
+    "        out = out.view(batch_size, C, H, W)\n",
+    "        return out\n",
+    "\n",
+    "# create multiple heads\n",
+    "class MultiHeadAttention(nn.Module):\n",
+    "    def __init__(self, in_channels, out_channels, kernel_size=1, heads=8) -> None:\n",
+    "        super(MultiHeadAttention, self).__init__()\n",
+    "        self.heads = heads\n",
+    "        self.attention = nn.ModuleList([SelfAttentionParam(in_channels, out_channels, kernel_size) for _ in range(heads)])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, C, H, W = x.size()\n",
+    "        out = torch.zeros(batch_size, C, H, W).to(x.device)\n",
+    "        for i in range(self.heads):\n",
+    "            out += self.attention[i](x)\n",
+    "        return out\n",
+    "    \n",
+    "\n",
+    "# create the same class but assuming the input is already flattened (N, num_features, embedding_size)\n",
+    "class SelfAttentionParamFlattened(nn.Module):\n",
+    "    def __init__(self, num_features, embedding_size) -> None:\n",
+    "        super(SelfAttentionParamFlattened, self).__init__()\n",
+    "        self.query = nn.Linear(num_features, embedding_size)\n",
+    "        self.key = nn.Linear(num_features, embedding_size)\n",
+    "        self.value = nn.Linear(num_features, embedding_size)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, N, E = x.size()\n",
+    "        Q = self.query(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n",
+    "        K = self.key(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n",
+    "        V = self.value(x.view(batch_size*N, -1)).view(batch_size, N, -1)\n",
+    "\n",
+    "        # Q, K, V = [batch_size, N, E]\n",
+    "        energy = torch.bmm(Q, K.transpose(1,2)) / math.sqrt(E)\n",
+    "        # energy = [batch_size, N, N]\n",
+    "        attention = torch.softmax(energy, dim=-1)\n",
+    "        out = torch.bmm(attention, V)\n",
+    "        return out\n",
+    "    \n",
+    "# create multiple heads\n",
+    "class MultiHeadAttentionFlattened(nn.Module):\n",
+    "    def __init__(self, num_features, embedding_size, heads=8) -> None:\n",
+    "        super(MultiHeadAttentionFlattened, self).__init__()\n",
+    "        self.heads = heads\n",
+    "        self.attention = nn.ModuleList([SelfAttentionParamFlattened(num_features, embedding_size) for _ in range(heads)])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, N, E = x.size()\n",
+    "        out = torch.zeros(batch_size, N, E).to(x.device)\n",
+    "        for i in range(self.heads):\n",
+    "            out += self.attention[i](x)\n",
+    "        return out\n",
+    "\n",
+    "\n",
+    "class Embedding(nn.Module):\n",
+    "    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n",
+    "        super(Embedding, self).__init__()\n",
+    "        self.patch_size = patch_size\n",
+    "        self.in_channels = in_channels\n",
+    "        self.out_channels = out_channels\n",
+    "        self.return_patches = return_patches\n",
+    "        self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n",
+    "        self.classify = extra_token\n",
+    "        self.patch_conv = nn.Conv2d(\n",
+    "            in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n",
+    "        self.norm = nn.LayerNorm(out_channels)\n",
+    "        self.proj = nn.Linear(out_channels, out_channels)\n",
+    "\n",
+    "    def get_patches(self, x, patch_size=8):\n",
+    "        # get the patches\n",
+    "        patches = x.unfold(2, patch_size, patch_size).unfold(\n",
+    "            3, patch_size, patch_size).to(x.device)\n",
+    "\n",
+    "        return patches\n",
+    "\n",
+    "    def get_pos_encoding(self, d_emb, max_len):\n",
+    "        pos = torch.arange(0, max_len).float().unsqueeze(1)\n",
+    "        i = torch.arange(0, d_emb, 2).float()\n",
+    "\n",
+    "        div = torch.exp(-i * math.log(10000) / d_emb)\n",
+    "\n",
+    "        sin = torch.sin(pos * div)\n",
+    "        cos = torch.cos(pos * div)\n",
+    "\n",
+    "        pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n",
+    "\n",
+    "        return pos_encoding\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "\n",
+    "        embedding = self.patch_conv(x)\n",
+    "\n",
+    "        # flatten the embedding\n",
+    "        embedding = embedding.reshape(x.shape[0], -1, self.out_channels)\n",
+    "\n",
+    "        if self.classify:\n",
+    "            class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1)\n",
+    "            embedding = torch.cat([class_embedding, embedding], dim=1)\n",
+    "\n",
+    "        # normalize the embedding\n",
+    "        embedding = self.norm(embedding)\n",
+    "\n",
+    "        # project the embedding\n",
+    "        embedding = self.proj(embedding)\n",
+    "\n",
+    "        # add the positional encoding account for batch size\n",
+    "        pos_encoding = self.get_pos_encoding(\n",
+    "            self.out_channels, embedding.shape[1]).to(x.device)\n",
+    "\n",
+    "        embedding = embedding + pos_encoding\n",
+    "\n",
+    "        # apply the dropout\n",
+    "        embedding = F.dropout(embedding, p=0.1, training=self.training)\n",
+    "\n",
+    "\n",
+    "        if self.return_patches:\n",
+    "            patches = self.get_patches(x, self.patch_size)\n",
+    "            patches = patches.reshape(\n",
+    "                x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n",
+    "\n",
+    "            return embedding, patches\n",
+    "        else:\n",
+    "            return embedding\n",
+    "        \n",
+    "\n",
+    "\n",
+    "# create the transformer class\n",
+    "class MyBasicTransformer(nn.Module):\n",
+    "    def __init__(self) -> None:\n",
+    "        super(MyBasicTransformer, self).__init__()\n",
+    "        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=512, return_patches=True)\n",
+    "        # use MultiHeadAttentionFlattened class\n",
+    "        self.attention = MultiHeadAttentionFlattened(num_features=512, embedding_size=512, heads=8)\n",
+    "        self.norm = nn.LayerNorm(512)\n",
+    "        self.mlp = nn.Sequential(\n",
+    "            nn.Linear(32768, 2048),  # input size should be 4096 instead of 512\n",
+    "            nn.GELU(),\n",
+    "            nn.Linear(2048, 512)\n",
+    "        )\n",
+    "        self.norm2 = nn.LayerNorm(512)\n",
+    "        self.classify = nn.Linear(512, 10)\n",
+    "\n",
+    "    def flatten(self, x):\n",
+    "        batch_size, N, E = x.size()\n",
+    "        return x.view(batch_size, -1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        embedding, patches = self.embedding(x)\n",
+    "        out = self.attention(embedding)\n",
+    "        out = self.norm(out)\n",
+    "        out = self.flatten(out)\n",
+    "        out = self.mlp(out)\n",
+    "        out = self.norm2(out)\n",
+    "        out = self.classify(out)\n",
+    "        return out, patches\n",
+    "    \n",
+    "\n",
+    "# create the dataset\n",
+    "transform = transforms.Compose([\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
+    "])\n",
+    "\n",
+    "trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
+    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)\n",
+    "\n",
+    "testset = CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
+    "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)\n",
+    "\n",
+    "# create the model\n",
+    "model = MyBasicTransformer()\n",
+    "model = model.to('cuda')\n",
+    "\n",
+    "# create the optimizer\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
+    "\n",
+    "# create the loss function\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "\n",
+    "# train the model\n",
+    "for epoch in range(10):\n",
+    "    running_loss = 0.0\n",
+    "    for i, data in enumerate(trainloader, 0):\n",
+    "        inputs, labels = data\n",
+    "        inputs = inputs.to('cuda')\n",
+    "        labels = labels.to('cuda')\n",
+    "        optimizer.zero_grad()\n",
+    "        outputs, _ = model(inputs)\n",
+    "        loss = criterion(outputs, labels)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        running_loss += loss.item()\n",
+    "        if i % 100 == 99:\n",
+    "            print('[%d, %5d] loss: %.3f' %\n",
+    "                  (epoch + 1, i + 1, running_loss / 100))\n",
+    "            running_loss = 0.0\n",
+    "\n",
+    "print('Finished Training')\n",
+    "\n",
+    "# test the model\n",
+    "correct = 0\n",
+    "total = 0\n",
+    "with torch.no_grad():\n",
+    "    for data in testloader:\n",
+    "        images, labels = data\n",
+    "        images = images.to('cuda')\n",
+    "        labels = labels.to('cuda')\n",
+    "        outputs, _ = model(images)\n",
+    "        _, predicted = torch.max(outputs.data, 1)\n",
+    "        total += labels.size(0)\n",
+    "        correct += (predicted == labels).sum().item()\n",
+    "\n",
+    "print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
+    "    100 * correct / total))\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([1, 8, 32, 32])\n",
+      "torch.Size([1, 128, 8])\n"
+     ]
+    }
+   ],
+   "source": [
+    "from torch import nn\n",
+    "import torch\n",
+    "\n",
+    "# create a transformer class model\n",
+    "from torch import nn\n",
+    "import torch\n",
+    "from torch.nn import functional as F\n",
+    "import math\n",
+    "\n",
+    "from torchvision.datasets import CIFAR10\n",
+    "# import torchvision transforms\n",
+    "from torchvision import transforms\n",
+    "import torchvision\n",
+    "\n",
+    "# create the same class but parameterized\n",
+    "class SelfAttentionParam(nn.Module):\n",
+    "    def __init__(self, in_channels, out_channels, kernel_size=1) -> None:\n",
+    "        super(SelfAttentionParam, self).__init__()\n",
+    "        self.query = nn.Conv2d(\n",
+    "            in_channels=in_channels, \n",
+    "            out_channels=in_channels, \n",
+    "            kernel_size=kernel_size\n",
+    "            )\n",
+    "        self.key = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)\n",
+    "        self.value = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, C, H, W = x.size()\n",
+    "        Q = self.query(x).view(batch_size, -1, H*W)\n",
+    "        K = self.key(x).view(batch_size, -1, H*W)\n",
+    "        V = self.value(x).view(batch_size, -1, H*W)\n",
+    "        # Q, K, V = [batch_size, N, N]\n",
+    "        energy = torch.bmm(Q.permute(0, 2, 1), K)\n",
+    "        # energy = [batch_size, N, N]\n",
+    "        attention = torch.softmax(energy, dim=-1)\n",
+    "        out = torch.bmm(V, attention.permute(0, 2, 1))\n",
+    "        out = out.view(batch_size, C, H, W)\n",
+    "        return out\n",
+    "\n",
+    "# create multiple heads\n",
+    "class MultiHeadAttention(nn.Module):\n",
+    "    def __init__(self, in_channels, out_channels, kernel_size=1, heads=8) -> None:\n",
+    "        super(MultiHeadAttention, self).__init__()\n",
+    "        self.heads = heads\n",
+    "        self.attention = nn.ModuleList([SelfAttentionParam(in_channels, out_channels, kernel_size) for _ in range(heads)])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, C, H, W = x.size()\n",
+    "        out = torch.zeros(batch_size, C, H, W).to(x.device)\n",
+    "        for i in range(self.heads):\n",
+    "            out += self.attention[i](x)\n",
+    "        return out\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "# get sample data for MultiHeadAttention\n",
+    "x = torch.randn(1, 8, 32, 32)\n",
+    "# create the model\n",
+    "model = MultiHeadAttention(in_channels=x.shape[1], out_channels=64, kernel_size=1, heads=8)\n",
+    "# forward pass\n",
+    "out = model(x)\n",
+    "print(out.shape)\n",
+    "\n",
+    "from torch import nn\n",
+    "model = nn.MultiheadAttention(embed_dim=8, num_heads=8)\n",
+    "\n",
+    "# create sample data that matches the input shape\n",
+    "x = torch.randn(1, 128, 8)\n",
+    "\n",
+    "# forward pass\n",
+    "out, _ = model(x, x, x)\n",
+    "\n",
+    "print(out.shape)\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create the embedding class with class token and positional encoding\n",
+    "class Embedding(nn.Module):\n",
+    "    def __init__(self, patch_size, in_channels, out_channels, return_patches=False) -> None:\n",
+    "        super(Embedding, self).__init__()\n",
+    "        self.patch_size = patch_size\n",
+    "        self.return_patches = return_patches\n",
+    "        self.patch_embedding = nn.Conv2d(\n",
+    "            in_channels=in_channels, \n",
+    "            out_channels=out_channels, \n",
+    "            kernel_size=patch_size, \n",
+    "            stride=patch_size\n",
+    "            )\n",
+    "        self.class_embedding = nn.Parameter(torch.randn(1, out_channels, 1, 1))\n",
+    "        self.positional_embedding = nn.Parameter(torch.randn(1, out_channels, 32, 32))\n",
+    "        self.norm = nn.LayerNorm(out_channels)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, C, H, W = x.size()\n",
+    "        patches = self.patch_embedding(x)\n",
+    "        patches = patches.flatten(2).transpose(1, 2)\n",
+    "        embedding = self.class_embedding.expand(batch_size, -1, -1, -1).flatten(2).transpose(1, 2) + patches\n",
+    "\n",
+    "\n",
+    "\n",
+    "        embedding = embedding + self.positional_embedding\n",
+    "        embedding = self.norm(embedding)\n",
+    "        if self.return_patches:\n",
+    "            return embedding, patches\n",
+    "        else:\n",
+    "            return embedding\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 94,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([32, 5, 16])\n"
+     ]
+    }
+   ],
+   "source": [
+    "class SelfAttentionParam(nn.Module):\n",
+    "    def __init__(self, in_features, out_features) -> None:\n",
+    "        super(SelfAttentionParam, self).__init__()\n",
+    "        self.query = nn.Linear(in_features, out_features)\n",
+    "        self.key = nn.Linear(in_features, out_features)\n",
+    "        self.value = nn.Linear(in_features, out_features)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        batch_size, num_embeddings, embedding_dim = x.size()\n",
+    "        Q = self.query(x).view(batch_size, num_embeddings, -1)\n",
+    "        K = self.key(x).view(batch_size, num_embeddings, -1)\n",
+    "        V = self.value(x).view(batch_size, num_embeddings, -1)\n",
+    "        # Q, K, V = [batch_size, num_embeddings, embedding_dim]\n",
+    "        energy = torch.bmm(Q, K.permute(0, 2, 1))\n",
+    "        # energy = [batch_size, num_embeddings, num_embeddings]\n",
+    "        attention = torch.softmax(energy, dim=-1)\n",
+    "        out = torch.bmm(attention, V)\n",
+    "        out = out.view(batch_size, num_embeddings, -1)\n",
+    "        return out\n",
+    "\n",
+    "\n",
+    "class MultiHeadAttention(nn.Module):\n",
+    "    def __init__(self, embedd_size, heads=8) -> None:\n",
+    "        super(MultiHeadAttention, self).__init__()\n",
+    "        self.heads = heads\n",
+    "        self.attention = nn.ModuleList([SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)])\n",
+    "        self.projection = nn.Linear(heads * embedd_size, embedd_size)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        out = [self.attention[i](x) for i in range(self.heads)]\n",
+    "        out = torch.cat(out, dim=2)\n",
+    "\n",
+    "        out = self.projection(out)\n",
+    "        return out\n",
+    "    \n",
+    "\n",
+    "# create a sample data\n",
+    "x = torch.randn(32, 10, 16)\n",
+    "\n",
+    "# create the model\n",
+    "model = MultiHeadAttention(embedd_size=16, heads=4)\n",
+    "\n",
+    "# forward pass\n",
+    "out = model(x)\n",
+    "\n",
+    "print(out.shape)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}