diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py
index 505317359038572f8e8a4f8015d372513e061a3c..45f47524c462b5d53f682db53bb1b9b4fab09def 100644
--- a/cifar_example/cifar_example_transformer.py
+++ b/cifar_example/cifar_example_transformer.py
@@ -8,6 +8,8 @@ import torch.nn.functional as F
 import torch.optim as optim
 import torchvision
 from torchvision import datasets, transforms
+import wandb
+
 # import resnet18 from trochvision
 from torchvision.models import resnet18
 
@@ -21,10 +23,10 @@ train_data = CIFAR10(root='./data', train=True, download=True, transform=transfo
 test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
 
 # get the train loader
-train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
+train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
 
 # get the test loader
-test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
+test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)
 
 criterion = torch.nn.CrossEntropyLoss()
 # train the model
@@ -42,6 +44,19 @@ def train(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader), loss.item()))
 
+            # log to wandb
+            wandb.log({"loss": loss.item()})
+            wandb.log({"epoch": epoch})
+
+            # compute the accuracy
+            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
+            correct = pred.eq(target.view_as(pred)).sum().item()
+            accuracy = correct / len(data)
+
+            # log to wandb
+            wandb.log({"accuracy": accuracy})
+
+
 # test the model
 def test(model, device, test_loader):
     model.eval()
@@ -84,22 +99,22 @@ class MyModel(nn.Module):
 
 
 #### define my transformer model
+# define embedding class
 class Embedding(nn.Module):
-    def __init__(self, patch_size, in_channels, out_channels, device='cpu', return_patches=False, extra_token=False):
+    def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):
         super(Embedding, self).__init__()
         self.patch_size = patch_size
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.device = device
         self.return_patches = return_patches
         self.classify = extra_token
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
         self.norm = nn.LayerNorm(out_channels)
-        self.extra_token = None # initialize extra_token tensor
+
 
     def get_patches(self, x, patch_size=8):
         # get the patches
-        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
+        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).to(x.device)
         return patches
 
 
@@ -121,12 +136,6 @@ class Embedding(nn.Module):
         patches = self.get_patches(x, patch_size=self.patch_size)
         # flatten the patches
         patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)
-        # add extra embedding token if classification is needed
-        if self.classify:
-            self.extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)
-            self.extra_token = self.extra_token.to(x.device) # move extra_token to the same device as x
-            patches = torch.cat((self.extra_token, patches), dim=0)
-
         # get the embedding
         embedding = self.conv(patches)
         # flatten the embedding
@@ -135,14 +144,22 @@ class Embedding(nn.Module):
         embedding = self.norm(embedding)
         # add the positional encoding
         pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])
-        pos_encoding = pos_encoding.to(x.device)
-        embedding = embedding + pos_encoding
+        embedding = embedding + pos_encoding.to(x.device)
+
+        # reshape the embedding
+        embedding = embedding.reshape(x.shape[0], -1, self.out_channels)
+
+        if self.classify:
+            # add the classification token
+            classification_token = torch.rand(x.shape[0], 1, self.out_channels).to(x.device)
+            embedding = torch.cat((classification_token, embedding), dim=1)
 
         if self.return_patches:
             return embedding, patches
         else:
             return embedding
 
+
 # define transformer class
 class SelfAttention(nn.Module):
     def __init__(self, embed_dim):
@@ -178,26 +195,28 @@ from torch.nn import TransformerEncoderLayer
 
 
 class PthBasedTransformer(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, embedding_size=64) -> None:
         super().__init__()
-        self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=8, return_patches=True, extra_token=True)
-        self.self_attention = TransformerEncoderLayer(d_model=8, nhead=8)
-        self.fc = nn.Linear(8, 10)
+        self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True)
+        self.self_attention = TransformerEncoderLayer(
+            d_model=embedding_size,
+            nhead=16,
+            dim_feedforward=embedding_size*4,
+            dropout=0.3
+        )
+        self.fc = nn.Linear(embedding_size, 10)
 
-
-
     def forward(self, x):
         embedding, patches = self.embedding(x)
         context = self.self_attention(embedding)
-
-
 
         # get the first token
         context = context[:, 0, :]
+        # context = context.mean(dim=1)
+
         # get the classification
         context = self.fc(context)
-
-
+
         return context
 
 
@@ -260,10 +279,29 @@ def main(train_model=False, model_type="resnet"):
 
 if __name__ == '__main__':
     train_model = True
-    model_type = "cnn"
+    # model_type = "cnn"
     # model_type = "resnet"
     model_type = "pth_transformer"
+
+    # Create a config object for wandb
+    config = {
+        'model_type': model_type,
+        'batch_size': 64,
+        'test_batch_size': 1000,
+        'epochs': 10,
+        'lr': 0.01,
+        'momentum': 0.5,
+        'no_cuda': False,
+        'seed': 1,
+        'log_interval': 10,
+        'patch_size': 8,
+        'in_channels': 3,
+        'out_channels': 64,
+        'extra_token': True
+    }
+
+    # add model type to wandb
+    wandb.init(project="cifar10_example_transformer", entity="maciej-wielgosz-nibio", config=config)
 
     main(
         train_model=train_model,