diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py
index f59818a97dde2841deabcc74b0c55618531c1bb2..6e77b5d024ae550c4a0bcec1fe310e6a6cf2231a 100644
--- a/cifar_example/cifar_example_transformer.py
+++ b/cifar_example/cifar_example_transformer.py
@@ -49,16 +49,6 @@ def train(model, device, train_loader, optimizer, epoch):
         # log to wandb
         wandb.log({"loss": loss.item()})
         wandb.log({"epoch": epoch})
-        # get all the parameters of the model
-        params = list(model.named_parameters())
-        # log the gradients
-        for name, param in params:
-            wandb.log({name + "_grad": wandb.Histogram(param.grad.cpu().numpy())})
-
-        # log the weights
-        for name, param in params:
-            wandb.log({name + "_weights": wandb.Histogram(param.detach().cpu().numpy())})
-
         # compute the accuracy
         pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
         correct = pred.eq(target.view_as(pred)).sum().item()
@@ -198,7 +188,6 @@ class SelfAttentionParam(nn.Module):
         out = out.view(batch_size, num_embeddings, -1)
         return out
 
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, embedd_size, heads=8) -> None:
         super(MultiHeadAttention, self).__init__()
@@ -215,12 +204,8 @@ class MultiHeadAttention(nn.Module):
 
 
 class MyTransformerLayer(nn.Module):
-    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
+    def __init__(self, d_model, nhead, dropout=0.1):
         super(MyTransformerLayer, self).__init__()
-        # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
-        print("d_model", d_model)
-        print("nhead", nhead)
-
         self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
         self.linear1 = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(dropout)
@@ -230,9 +215,7 @@ class MyTransformerLayer(nn.Module):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-    def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        # src2 = self.self_attn(src, src, src, attn_mask=src_mask,
-        #                       key_padding_mask=src_key_padding_mask)[0]
+    def forward(self, src):
         src2 = self.self_attn(src)
         src = src + self.dropout1(src2)
         src = self.norm1(src)
@@ -245,7 +228,7 @@ class MyTransformerLayer(nn.Module):
 class PthBasedTransformer(nn.Module):
     def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
         # self.self_attention = TransformerEncoderLayer(
         #     d_model=embedding_size,
         #     nhead=16,
@@ -254,7 +237,7 @@ class PthBasedTransformer(nn.Module):
         #     batch_first=True
         # )
 
-        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3, batch_first=True)
+        self.self_attention = MyTransformerLayer(d_model=embedding_size, nhead=16, dropout=0.3)
         self.fc = nn.Linear(embedding_size, 10)
 
     def forward(self, x):
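
For context, below is a minimal, self-contained sketch of what the trimmed-down MyTransformerLayer computes after this patch. The repo's MultiHeadAttention internals and the feed-forward half of forward() are not visible in the hunks above, so NaiveMultiHeadAttention (a thin wrapper around torch.nn.MultiheadAttention) and the linear1/linear2 block are illustrative assumptions, not the project's actual implementation.

import torch
import torch.nn as nn


class NaiveMultiHeadAttention(nn.Module):
    # stand-in for the repo's MultiHeadAttention(embedd_size, heads): takes a single
    # (batch, tokens, embedd_size) tensor and returns self-attended features
    def __init__(self, embedd_size, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embedd_size, heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


class TransformerLayerSketch(nn.Module):
    # mirrors MyTransformerLayer after this patch: no mask arguments, no batch_first flag
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = NaiveMultiHeadAttention(embedd_size=d_model, heads=nhead)
        self.linear1 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        # attention block with residual connection and post-norm, as in the diff
        src2 = self.self_attn(src)
        src = self.norm1(src + self.dropout1(src2))
        # assumed feed-forward block (this part of forward() falls outside the hunk)
        src2 = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src


if __name__ == "__main__":
    # 32x32 CIFAR images with patch_size=8 give 4x4 = 16 patches; assuming
    # Embedding(extra_token=True) appends one extra token, that is 17 tokens of size 64
    tokens = torch.randn(4, 17, 64)
    layer = TransformerLayerSketch(d_model=64, nhead=16, dropout=0.3)
    print(layer(tokens).shape)  # torch.Size([4, 17, 64])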