Commit 293c5f50 authored by Maciej Wielgosz

cleaning the basic cifar transformer code

parent 20b84708
@@ -49,16 +49,6 @@ def train(model, device, train_loader, optimizer, epoch):
         # log to wandb
         wandb.log({"loss": loss.item()})
         wandb.log({"epoch": epoch})
-        # get all the parameters of the model
-        params = list(model.named_parameters())
-        # log the gradients
-        for name, param in params:
-            wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})
-        # log the weights
-        for name, param in params:
-            wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
         # compute the accuracy
         pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
         correct = pred.eq(target.view_as(pred)).sum().item()
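For reference, the per-parameter logging removed above can be kept as a small optional helper instead of inline loop code; a minimal sketch (not part of this commit), assuming an active wandb run and that it is called after loss.backward():

    import wandb

    def log_param_histograms(model):
        # Log weight and gradient histograms for every named parameter,
        # mirroring the block removed in this commit; gradients exist only
        # after loss.backward(), so guard against None.
        for name, param in model.named_parameters():
            wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
            if param.grad is not None:
                wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})

Calling such a helper only every few hundred batches keeps the histogram overhead low compared with the per-step logging that was removed.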
@@ -198,7 +188,6 @@ class SelfAttentionParam(nn.Module):
         out = out.view(batch_size, num_embeddings, -1)
         return out
 
 class MultiHeadAttention(nn.Module):
     def __init__(self, embedd_size, heads=8) -> None:
         super(MultiHeadAttention, self).__init__()
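The body of MultiHeadAttention lies outside this hunk; below is a hypothetical, self-contained stand-in that matches its visible interface (embedd_size, heads, and a single-tensor call from MyTransformerLayer.forward), assuming standard scaled dot-product attention:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttentionSketch(nn.Module):
        # Hypothetical stand-in: input and output are (batch, tokens, embedd_size).
        def __init__(self, embedd_size, heads=8) -> None:
            super().__init__()
            assert embedd_size % heads == 0, "embedd_size must be divisible by heads"
            self.heads = heads
            self.head_dim = embedd_size // heads
            self.qkv = nn.Linear(embedd_size, 3 * embedd_size)
            self.proj = nn.Linear(embedd_size, embedd_size)

        def forward(self, x):
            b, n, e = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            # split the embedding into heads: (batch, heads, tokens, head_dim)
            q, k, v = (t.view(b, n, self.heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
            attn = F.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
            out = (attn @ v).transpose(1, 2).reshape(b, n, e)
            return self.proj(out)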
@@ -215,12 +204,8 @@ class MultiHeadAttention(nn.Module):
 class MyTransformerLayer(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
+    def __init__(self, d_model, nhead, dropout=0.1):
         super(MyTransformerLayer, self).__init__()
-        # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
-        print("d_model", d_model)
-        print("nhead", nhead)
         self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
         self.linear1 = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
@@ -230,9 +215,7 @@ class MyTransformerLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-    def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        # src2 = self.self_attn(src, src, src, attn_mask=src_mask,
-        #                       key_padding_mask=src_key_padding_mask)[0]
+    def forward(self, src):
         src2 = self.self_attn(src)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
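The commented-out call removed here follows the torch.nn.MultiheadAttention API, whose forward takes separate query/key/value tensors (plus optional masks) and returns an (output, attention-weights) tuple, hence the [0]; the custom module is called with a single tensor and returns only the output. A side-by-side sketch, assuming batch-first tensors of shape (batch, tokens, d_model):

    import torch
    import torch.nn as nn

    d_model, nhead = 64, 16
    src = torch.randn(8, 17, d_model)  # (batch, tokens, embedding)

    # previous path: the built-in attention returns (output, weights)
    builtin = nn.MultiheadAttention(d_model, nhead, batch_first=True)
    src2 = builtin(src, src, src)[0]

    # current path: the custom module defined earlier in this file
    # src2 = MultiHeadAttention(embedd_size=d_model, heads=nhead)(src)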
@@ -245,7 +228,7 @@ class MyTransformerLayer(nn.Module):
 class PthBasedTransformer(nn.Module):
     def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
         # self.self_attention = TransformerEncoderLayer(
         #     d_model=embedding_size,
         #     nhead=16,
@@ -254,7 +237,7 @@ class PthBasedTransformer(nn.Module):
         #     batch_first=True
         # )
-        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True)
+        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3)
         self.fc = nn.Linear(embedding_size, 10)
 
     def forward(self, x):
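With patch_size reduced from 16 to 8, a 32x32 CIFAR image splits into a 4x4 grid of 16 patches, and the extra token brings the sequence to 17 tokens of width embedding_size. A hypothetical smoke test (not part of the diff), assuming the model's forward returns class logits for CIFAR-10:

    import torch

    # assumes PthBasedTransformer and Embedding are defined as in this file
    model = PthBasedTransformer(embedding_size=64)
    x = torch.randn(4, 3, 32, 32)      # a dummy CIFAR-10 batch
    logits = model(x)
    print(logits.shape)                # expected: torch.Size([4, 10])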