diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py
index f59818a97dde2841deabcc74b0c55618531c1bb2..6e77b5d024ae550c4a0bcec1fe310e6a6cf2231a 100644
--- a/cifar_example/cifar_example_transformer.py
+++ b/cifar_example/cifar_example_transformer.py
@@ -49,16 +49,6 @@ def train(model, device, train_loader, optimizer, epoch):
             # log to wandb
             wandb.log({"loss": loss.item()})
             wandb.log({"epoch": epoch})
-            # get all the parameters of the model
-            params = list(model.named_parameters())
-            # log the gradients
-            for name, param in params:
-                wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})
-
-            # log the weights
-            for name, param in params:
-                wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
-            
             # compute the accuracy
             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct = pred.eq(target.view_as(pred)).sum().item()
@@ -198,7 +188,6 @@ class SelfAttentionParam(nn.Module):
         out = out.view(batch_size, num_embeddings, -1)
         return out
 
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, embedd_size, heads=8) -> None:
         super(MultiHeadAttention, self).__init__()
@@ -215,12 +204,8 @@ class MultiHeadAttention(nn.Module):
     
 
 class MyTransformerLayer(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
+    def __init__(self, d_model, nhead, dropout=0.1):
         super(MyTransformerLayer, self).__init__()
-        # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
-        print("d_model", d_model)
-        print("nhead", nhead)
-
         self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
         self.linear1 = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
@@ -230,9 +215,7 @@ class MyTransformerLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-    def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        # src2 = self.self_attn(src, src, src, attn_mask=src_mask,
-        #                       key_padding_mask=src_key_padding_mask)[0]
+    def forward(self, src):
         src2 = self.self_attn(src)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
@@ -245,7 +228,7 @@ class MyTransformerLayer(nn.Module):
 class PthBasedTransformer(nn.Module):
     def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
         # self.self_attention = TransformerEncoderLayer(
         #     d_model=embedding_size, 
         #     nhead=16, 
@@ -254,7 +237,7 @@ class PthBasedTransformer(nn.Module):
         #     batch_first=True
         #     )
         
-        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True)
+        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3)
         self.fc = nn.Linear(embedding_size, 10)
         
     def forward(self, x):
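
For reference, below is a minimal, self-contained sketch (not part of the patch) of what the simplified layer looks like after this change. DemoMultiHeadAttention and DemoTransformerLayer are stand-in names: the real MultiHeadAttention in cifar_example_transformer.py is assumed to return a (batch, tokens, embedd_size) tensor, and the feed-forward half of MyTransformerLayer is not visible in these hunks, so the standard post-norm layout is assumed here. The shape check at the end also illustrates the effect of switching patch_size from 16 to 8: on 32x32 CIFAR images that is (32/8)**2 = 16 patches plus the extra classification token, i.e. 17 tokens per image instead of 5.

# Sketch only -- assumes the post-norm feed-forward layout; the real
# MultiHeadAttention in the repo is replaced by a stand-in built on
# nn.MultiheadAttention so the snippet runs on its own.
import torch
import torch.nn as nn

class DemoMultiHeadAttention(nn.Module):
    # stand-in attention returning (batch, tokens, embedd_size)
    def __init__(self, embedd_size, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embedd_size, heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out

class DemoTransformerLayer(nn.Module):
    # mirrors the simplified MyTransformerLayer signature: no batch_first,
    # no src_mask / src_key_padding_mask
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = DemoMultiHeadAttention(embedd_size=d_model, heads=nhead)
        self.linear1 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        # attention block, as in the patched forward()
        src2 = self.self_attn(src)
        src = self.norm1(src + self.dropout1(src2))
        # feed-forward block (assumed; not shown in the diff hunks)
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src

# patch_size=8 on 32x32 images -> 16 patches + 1 extra token = 17 tokens
x = torch.randn(4, 17, 64)
layer = DemoTransformerLayer(d_model=64, nhead=16, dropout=0.3)
print(layer(x).shape)  # torch.Size([4, 17, 64])

If the per-parameter gradient/weight histograms removed from train() are still occasionally useful, wandb.watch(model, log="all", log_freq=...) is a lighter-weight alternative to logging them by hand on every step.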