diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py
index 505317359038572f8e8a4f8015d372513e061a3c..45f47524c462b5d53f682db53bb1b9b4fab09def 100644
--- a/cifar_example/cifar_example_transformer.py
+++ b/cifar_example/cifar_example_transformer.py
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torchvision
 from torchvision import datasets, transforms
+import wandb
+
 
 # import resnet18 from trochvision
 from torchvision.models import resnet18
@@ -21,10 +23,10 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
 test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
 
 # get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
 
 # get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)
 criterion = torch.nn.CrossEntropyLoss()
 
 # train the model
@@ -42,6 +44,19 @@ def train(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
 
+        # compute the batch accuracy
+        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+        correct = pred.eq(target.view_as(pred)).sum().item()
+        accuracy = correct / len(data)
+
+        # log loss, accuracy and epoch in a single call so they share one wandb step
+        wandb.log({
+            "loss": loss.item(),
+            "accuracy": accuracy,
+            "epoch": epoch,
+        })
+
+
 # test the model
 def test(model, device, test_loader):
     model.eval()
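
The hunk above logs per-batch training metrics only. A natural counterpart, sketched here rather than included in the diff, is one epoch-level call from test(); the extended signature and the local names test_loss and correct are assumptions, since the body of test() is not shown in this hunk.

# sketch: aggregate test metrics logged once per epoch (the epoch argument is an assumption)
def test(model, device, test_loader, epoch):
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()            # module-level criterion
            correct += output.argmax(dim=1).eq(target).sum().item()
    wandb.log({"test_loss": test_loss / len(test_loader),
               "test_accuracy": correct / len(test_loader.dataset),
               "epoch": epoch})
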
@@ -84,22 +99,22 @@ class MyModel(nn.Module):
 
 #### define my transformer model
 
+# define embedding class
 class Embedding(nn.Module):
-    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
         super(Embedding, self).__init__()
         self.patch_size = patch_size
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.device = device
         self.return_patches = return_patches
         self.classify = extra_token
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
         self.norm = nn.LayerNorm(out_channels)
-        self.extra_token = None  # initialize extra_token tensor
+
 
     def get_patches(self, x, patch_size=8):
         # get the patches
-        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
 
         return patches
 
@@ -121,12 +136,6 @@ class Embedding(nn.Module):
         patches = self.get_patches(x, patch_size=self.patch_size)
         # flatten the patches
         patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
-        # add extra embedding token if classification is needed
-        if self.classify:
-            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
-            self.extra_token = self.extra_token.to(x.device)  # move extra_token to the same device as x
-            patches = torch.cat((self.extra_token, patches), dim=0)
-   
         # get the embedding
         embedding = self.conv(patches)
         # flatten the embedding
@@ -135,14 +144,22 @@ class Embedding(nn.Module):
         embedding = self.norm(embedding)
         # add the positional encoding
         pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-        pos_encoding = pos_encoding.to(x.device)
-        embedding = embedding + pos_encoding
+        embedding = embedding + pos_encoding.to(x.device)
+
+        # reshape the embedding
+        embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
+
+        if self.classify:
+            # prepend a classification token (re-sampled each forward pass, not learned)
+            classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
+            embedding = torch.cat((classification_token, embedding), dim=1)
         
         if self.return_patches:
             return embedding, patches
         else:
             return embedding
 
+
 # define transformer class
 class SelfAttention(nn.Module):
     def __init__(self, embed_dim):
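
The classification token added in this hunk is drawn from torch.rand on every forward pass, so its content is never trained. The usual ViT-style alternative is a learnable parameter; a minimal self-contained sketch of that pattern follows (PatchTokensWithCls is illustrative and not a class in this file).

# sketch: learnable class token, broadcast over the batch at forward time
class PatchTokensWithCls(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, tokens):  # tokens: (batch, num_patches, embed_dim)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        return torch.cat((cls, tokens), dim=1)
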
@@ -178,26 +195,28 @@ from torch.nn import TransformerEncoderLayer
 
 
 class PthBasedTransformer(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
-        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
-        self.fc = nn.Linear(8, 10)
+        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.self_attention = TransformerEncoderLayer(
+            d_model=embedding_size,
+            nhead=16,
+            dim_feedforward=embedding_size*4,
+            dropout=0.3,
+            batch_first=True)  # the embedding is (batch, tokens, features)
+        self.fc = nn.Linear(embedding_size, 10)
         
-   
-
     def forward(self, x):
         embedding, patches = self.embedding(x)
         context = self.self_attention(embedding)
-    
-
         # get the first token
         context = context[:, 0, :]
 
+        # alternative readout: context = context.mean(dim=1)  (mean over all tokens)
+
         # get the classification
         context = self.fc(context)
-        
-
+
         return context
 
 
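
The commented-out line in forward points at mean pooling as an alternative readout. A sketch of that variant of PthBasedTransformer.forward, assuming the embedding is built with extra_token=False so every position is a patch token:

# sketch: mean-pooled readout instead of taking the class token
def forward(self, x):
    embedding, patches = self.embedding(x)   # (batch, tokens, embed_dim)
    context = self.self_attention(embedding)
    context = context.mean(dim=1)            # average over all patch tokens
    return self.fc(context)
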
@@ -260,10 +279,29 @@ def main(train_model=False, model_type="resnet"):
 
 if __name__ == '__main__':
     train_model = True
-    model_type = "cnn"
+    # model_type = "cnn"
     # model_type = "resnet"
     model_type = "pth_transformer"
 
+    # Create a config object for wandb
+    config = {
+        'model_type': model_type,
+        'batch_size': 32,  # matches the DataLoader batch size above
+        'test_batch_size': 32,
+        'epochs': 10,
+        'lr': 0.01,
+        'momentum': 0.5,
+        'no_cuda': False,
+        'seed': 1,
+        'log_interval': 10,
+        'patch_size': 16,  # matches the patch size used by PthBasedTransformer
+        'in_channels': 3,
+        'out_channels': 64,
+        'extra_token': True
+    }
+
+    # initialize wandb with the run configuration
+    wandb.init(project="cifar10_example_transformer", entity="maciej-wielgosz-nibio", config=config)
 
     main(
         train_model=train_model,
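
The dict passed to wandb.init only documents the run; nothing in the script reads it back, so the logged values and the literals used when the loaders and the model are built have to be kept in sync by hand. One way to avoid that, sketched with a hypothetical get_loaders helper, is to derive the loaders from wandb.config after init:

# sketch: build the loaders from wandb.config so the logged values cannot drift
def get_loaders(cfg):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=cfg.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=cfg.test_batch_size, shuffle=False)
    return train_loader, test_loader

# usage, after wandb.init(..., config=config):
#     train_loader, test_loader = get_loaders(wandb.config)
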