Commit 293c5f50 authored by Maciej Wielgosz

cleaning the basic cifar transformer code

parent 20b84708
@@ -49,16 +49,6 @@ def train(model, device, train_loader, optimizer, epoch):
         # log to wandb
         wandb.log({"loss": loss.item()})
         wandb.log({"epoch": epoch})
-        # get all the parameters of the model
-        params = list(model.named_parameters())
-        # log the gradients
-        for name, param in params:
-            wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})
-        # log the weights
-        for name, param in params:
-            wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
         # compute the accuracy
         pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
         correct = pred.eq(target.view_as(pred)).sum().item()
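
After this change each training step logs only the scalar loss and epoch and computes batch accuracy; the per-parameter gradient and weight histograms are gone. A minimal sketch of the remaining per-step logic, using a hypothetical log_step helper and assuming output/target are the usual logits and labels from the loop:

import torch
import wandb

def log_step(output, target, loss, epoch):
    # Hypothetical helper mirroring what train() still logs after this commit:
    # scalar loss and epoch only; no gradient/weight histograms.
    wandb.log({"loss": loss.item()})
    wandb.log({"epoch": epoch})
    # index of the max log-probability per sample
    pred = output.argmax(dim=1, keepdim=True)
    correct = pred.eq(target.view_as(pred)).sum().item()
    return correct
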
@@ -198,7 +188,6 @@ class SelfAttentionParam(nn.Module):
         out = out.view(batch_size, num_embeddings, -1)
         return out
 class MultiHeadAttention(nn.Module):
     def __init__(self, embedd_size, heads=8) -> None:
         super(MultiHeadAttention, self).__init__()
@@ -215,12 +204,8 @@ class MultiHeadAttention(nn.Module):
 class MyTransformerLayer(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
+    def __init__(self, d_model, nhead, dropout=0.1):
         super(MyTransformerLayer, self).__init__()
-        # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
-        print("d_model", d_model)
-        print("nhead", nhead)
         self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
         self.linear1 = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
@@ -230,9 +215,7 @@ class MyTransformerLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
-    def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        # src2 = self.self_attn(src, src, src, attn_mask=src_mask,
-        #                       key_padding_mask=src_key_padding_mask)[0]
+    def forward(self, src):
         src2 = self.self_attn(src)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
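
With the mask arguments removed, the layer's forward pass takes a single tensor. A minimal usage sketch, assuming the custom MultiHeadAttention defined in this file maps (batch, tokens, d_model) to the same shape and that MyTransformerLayer is importable from this module:

import torch

# Hypothetical shapes: 4 images, 17 tokens (16 patches + 1 extra token), 64-dim embeddings.
layer = MyTransformerLayer(d_model=64, nhead=16, dropout=0.3)
src = torch.randn(4, 17, 64)
out = layer(src)  # forward(src) only; no src_mask / src_key_padding_mask anymore
assert out.shape == src.shape
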
@@ -245,7 +228,7 @@ class MyTransformerLayer(nn.Module):
 class PthBasedTransformer(nn.Module):
     def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
         # self.self_attention = TransformerEncoderLayer(
         #     d_model=embedding_size,
         #     nhead=16,
@@ -254,7 +237,7 @@ class PthBasedTransformer(nn.Module):
         #     batch_first=True
         # )
-        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True)
+        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3)
         self.fc = nn.Linear(embedding_size, 10)
     def forward(self, x):
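
The patch_size change also changes the token count: a 32x32 CIFAR image split into 8x8 patches gives (32/8)^2 = 16 patches, plus the extra token, each embedded into embedding_size=64 before the 10-way linear head. A rough end-to-end sketch under those assumptions (the exact forward() behaviour is not shown in this diff):

import torch

# Assumed CIFAR-10 style input: batch of 4 RGB images, 32x32 pixels.
model = PthBasedTransformer(embedding_size=64)
x = torch.randn(4, 3, 32, 32)
logits = model(x)
# Embedding(patch_size=8) -> 16 patches + 1 extra token, 64-dim each,
# then MyTransformerLayer and nn.Linear(64, 10).
print(logits.shape)  # expected torch.Size([4, 10]) if forward() classifies via self.fc
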