"""Train a small patch-transformer classifier on CIFAR-10 with PyTorch Lightning.

Contents of ``cifar_example/cifar10_lightning_ver_2.py``:

* ``CIFAR10DataModule`` - downloads CIFAR-10 and serves train/val/test loaders.
* ``CIFAR10Model`` - patch embedding + one transformer layer + linear head.
* ``main`` - builds the trainer, fits, tests and checkpoints the model.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision import transforms
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, WandbLogger

from torchvision.models import resnet18

# import modules
from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer
from cifar_example.cifar_transformer_modules.embedding import Embedding


# create datamodule for cifar10
class CIFAR10DataModule(pl.LightningDataModule):
    """Lightning data module serving CIFAR-10 train/val/test splits.

    Args:
        data_dir: Directory the dataset is downloaded to / read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader (default 16, the value
            that used to be hard-coded for the train/val loaders).
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32, num_workers: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Convert to tensors, then normalise each channel with mean 0.5 / std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (``"fit"``, ``"test"`` or ``None`` for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            # 45,000 training images, 5,000 held out for validation.
            self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; shuffled so every epoch sees a new batch order."""
        return DataLoader(
            self.cifar10_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=self.num_workers)


class CIFAR10Model(pl.LightningModule):
    """CIFAR-10 classifier: patch embedding -> one transformer layer -> linear head.

    The file previously defined every method twice inside this class (a ResNet
    variant followed by a transformer variant whose ``class`` line was commented
    out), so only the later, transformer definitions ever took effect.  Those
    are the ones kept here; ``create_model`` (the ResNet-18 baseline builder)
    is retained as a helper but is not used by ``__init__``.
    """

    def __init__(self):
        super().__init__()
        self.embedding_size = 64
        self.criterion = torch.nn.CrossEntropyLoss()
        self.embedding = Embedding(
            patch_size=16,
            in_channels=3,
            out_channels=self.embedding_size,
            return_patches=False,
            extra_token=True,
        )
        self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3)
        self.fc = nn.Linear(self.embedding_size, 10)

    def create_model(self):
        """Build a ResNet-18 with a 3x3 stem and no max-pool, 10 output classes.

        Not called by ``__init__``; kept so the ResNet baseline stays available.
        """
        # No pretrained weights are requested: relying on the default avoids the
        # `pretrained=` keyword, which newer torchvision releases deprecate.
        model = resnet18(num_classes=10)
        model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        model.maxpool = nn.Identity()
        return model

    def forward(self, x):
        """Return class logits computed from the first (extra) token."""
        embedding = self.embedding(x)
        context = self.self_attention(embedding)
        # get the first token
        context = context[:, 0, :]
        # get the classification
        return self.fc(context)

    def _shared_step(self, batch, stage):
        """Compute cross-entropy loss, log ``{stage}_loss`` / ``{stage}_acc``, return the loss."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # log accuracy
        self.log(
            f"{stage}_acc",
            self.accuracy(y_hat, y),
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Use Adam with a learning rate of 0.001 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def accuracy(self, y_hat, y):
        """Return the fraction of rows in ``y_hat`` whose argmax equals ``y``."""
        preds = torch.argmax(y_hat, dim=1)
        return (preds == y).float().mean()


def main():
    """Train, test and checkpoint the model, then print the logged metrics."""
    # create trainer
    trainer = pl.Trainer(
        gpus=1 if torch.cuda.is_available() else None,
        max_epochs=10,
        callbacks=[TQDMProgressBar(refresh_rate=20)],
        logger=WandbLogger(
            entity="maciej-wielgosz-nibio",
            project="cifar10_example_transformer",
            log_model=True,
            save_code=True,
            save_dir="wandb/",
        ),
    )

    # create datamodule
    dm = CIFAR10DataModule(batch_size=64)

    # create model
    model = CIFAR10Model()

    # train model
    trainer.fit(model, dm)

    # test model
    trainer.test(model, datamodule=dm)

    # save model
    trainer.save_checkpoint("cifar10_model.ckpt")

    # print the metrics
    print(trainer.logged_metrics)


# Guarded entry point: DataLoader worker processes started with the "spawn"
# method re-import this module, which must not launch another training run.
if __name__ == "__main__":
    main()
"""Building blocks of the CIFAR patch transformer.

This span holds the classes of three sibling modules of
``cifar_example.cifar_transformer_modules``: ``embedding`` (``Embedding``),
``multi_head_attention`` (``SelfAttentionParam``, ``MultiHeadAttention``) and
``my_transformer_layer`` (``MyTransformerLayer``).
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------- embedding.py
class Embedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens.

    Args:
        patch_size: Side length of the square, non-overlapping patches.
        in_channels: Number of image channels.
        out_channels: Token (embedding) dimension.
        batch_size: Unused; accepted for backward compatibility.
        return_patches: If true, ``forward`` also returns the raw patches.
        extra_token: If true, a learned token is prepended to every sequence.
    """

    def __init__(self, patch_size, in_channels, out_channels, batch_size=1,
                 return_patches=False, extra_token=False):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.return_patches = return_patches
        self.class_embedding = nn.Parameter(torch.randn(1, out_channels))
        self.classify = extra_token
        # A conv with kernel == stride == patch_size embeds each patch independently.
        self.patch_conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(out_channels)
        self.proj = nn.Linear(out_channels, out_channels)

    def get_patches(self, x, patch_size=8):
        """Cut ``x`` (B, C, H, W) into patches of shape (B, C, H/p, W/p, p, p)."""
        return x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)

    def get_pos_encoding(self, d_emb, max_len, device=None):
        """Return a sinusoidal table of shape ``(1, max_len, d_emb)``.

        The first half of the feature axis holds the sines and the second half
        the cosines.  For odd ``d_emb`` the surplus last column is dropped
        (previously an odd ``d_emb`` raised a shape error).

        Args:
            d_emb: Feature dimension of the table.
            max_len: Number of positions.
            device: Optional device on which to build the table.
        """
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(1)
        i = torch.arange(0, d_emb, 2, device=device).float()

        div = torch.exp(-i * math.log(10000) / d_emb)

        angles = pos * div
        table = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)

        return table[:, :d_emb].reshape(1, max_len, d_emb)

    def forward(self, x):
        """Embed ``x`` (B, C, H, W) into tokens of shape (B, N[+1], out_channels).

        Returns:
            The token tensor, or ``(tokens, patches)`` when ``return_patches``
            is set, where ``patches`` has shape (B, N, C, p, p) in the same
            row-major patch order as the tokens.
        """
        batch = x.shape[0]

        embedding = self.patch_conv(x)  # (B, out_channels, H/p, W/p)

        # Move the channel axis last *before* flattening the patch grid.  A bare
        # reshape to (B, -1, out_channels) would interleave channels and patch
        # positions, so a token would no longer correspond to one patch.
        embedding = embedding.flatten(2).transpose(1, 2)  # (B, N, out_channels)

        if self.classify:
            class_embedding = self.class_embedding.expand(batch, 1, -1)
            embedding = torch.cat([class_embedding, embedding], dim=1)

        # normalize the embedding
        embedding = self.norm(embedding)

        # project the embedding
        embedding = self.proj(embedding)

        # add the positional encoding (broadcast over the batch axis)
        pos_encoding = self.get_pos_encoding(
            self.out_channels, embedding.shape[1], device=x.device)
        embedding = embedding + pos_encoding

        if not self.return_patches:
            return embedding

        patches = self.get_patches(x, self.patch_size)  # (B, C, nH, nW, p, p)
        # Same axis-order concern as above: bring the patch grid in front of the
        # channel axis so that patches[b, k] is one complete C x p x p patch.
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
            batch, -1, self.in_channels, self.patch_size, self.patch_size)
        return embedding, patches


# ----------------------------------------------------- multi_head_attention.py
class SelfAttentionParam(nn.Module):
    """Single-head scaled dot-product self-attention with learned Q/K/V maps."""

    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.query = nn.Linear(in_features, out_features)
        self.key = nn.Linear(in_features, out_features)
        self.value = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Attend over ``x`` (batch, tokens, in_features) -> (batch, tokens, out_features)."""
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scale the logits by 1/sqrt(d) so the softmax does not saturate as the
        # feature dimension grows (standard scaled dot-product attention).
        energy = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
        # energy = [batch, tokens, tokens]
        attention = torch.softmax(energy, dim=-1)
        return torch.bmm(attention, v)


class MultiHeadAttention(nn.Module):
    """Run ``heads`` independent full-width attention heads and project back.

    Every head maps ``embedd_size -> embedd_size``; the concatenated head
    outputs (``heads * embedd_size`` features) are reduced by a linear layer.
    """

    def __init__(self, embedd_size, heads=8) -> None:
        super().__init__()
        self.heads = heads
        self.attention = nn.ModuleList(
            [SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)])
        self.projection = nn.Linear(heads * embedd_size, embedd_size)

    def forward(self, x):
        """Map ``x`` (batch, tokens, embedd_size) to a tensor of the same shape."""
        out = torch.cat([head(x) for head in self.attention], dim=2)
        return self.projection(out)


# ----------------------------------------------------- my_transformer_layer.py
# NOTE: in its own module file this class obtains MultiHeadAttention via
# `from cifar_example.cifar_transformer_modules.multi_head_attention import MultiHeadAttention`.
class MyTransformerLayer(nn.Module):
    """Post-norm transformer encoder layer: attention and MLP, each with a residual.

    Args:
        d_model: Token dimension (also the hidden width of the MLP).
        nhead: Number of attention heads.
        dropout: Dropout probability used in all three dropout layers.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead)
        self.linear1 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src):
        """Apply the layer to ``src`` (batch, tokens, d_model); shape is preserved."""
        # Attention sub-layer: residual connection, then LayerNorm.
        attended = self.self_attn(src)
        src = self.norm1(src + self.dropout1(attended))
        # Feed-forward sub-layer: residual connection, then LayerNorm.
        hidden = self.linear2(self.dropout(F.relu(self.linear1(src))))
        return self.norm2(src + self.dropout2(hidden))
a3cd4854564118826b197d285a3f8fdab47a2e7e..92350a9a2d81ef8187e7c59eb954d11e5f80392a 100644 --- a/cifar_example/vis.ipynb +++ b/cifar_example/vis.ipynb @@ -2,27 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Files already downloaded and verified\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import torch\n", "\n", @@ -62,121 +44,22 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "no: 8\n", - "patches 10: tensor([[[0.2314, 0.2431, 0.2471],\n", - " [0.1686, 0.1804, 0.1765],\n", - " [0.1961, 0.1882, 0.1686],\n", - " [0.2667, 0.2118, 0.1647]],\n", - "\n", - " [[0.0627, 0.0784, 0.0784],\n", - " [0.0000, 0.0000, 0.0000],\n", - " [0.0706, 0.0314, 0.0000],\n", - " [0.2000, 0.1059, 0.0314]],\n", - "\n", - " [[0.0980, 0.0941, 0.0824],\n", - " [0.0627, 0.0275, 0.0000],\n", - " [0.1922, 0.1059, 0.0314],\n", - " [0.3255, 0.1961, 0.0902]],\n", - "\n", - " [[0.1294, 0.0980, 0.0667],\n", - " [0.1490, 0.0784, 0.0157],\n", - " [0.3412, 0.2118, 0.0980],\n", - " [0.4157, 0.2471, 0.1098]]])\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 640x480 with 64 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from torch import nn\n", - "\n", - "patch_size = 4\n", - "\n", - "def get_patches(x, patch_size=8):\n", - " # get the batch size\n", - " patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)\n", - "\n", - " return patches\n", - "\n", - "# run the function on the first image\n", - "patches = get_patches(img.unsqueeze(0), patch_size=patch_size)\n", - "\n", - "# show the patches using 
patches.shape\n", - "no = int(32 / patch_size)\n", - "\n", - "print('no: ', no)\n", - "\n", - "fig, ax = plt.subplots(no, no)\n", - "for i in range(no):\n", - " for j in range(no):\n", - " ax[i, j].imshow(patches[0, :, i, j, :].permute(1, 2, 0))\n", - " ax[i, j].axis('off')\n", - "\n", - "print('patches 10: ', patches[0, :, 0, 0, :].permute(1, 2, 0))" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "patch size: 16\n", - "patch size: 16\n", - "patches: torch.Size([4, 3, 16, 16])\n", - "embedding: torch.Size([4, 8])\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 640x480 with 4 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAD5CAYAAAC6TTYBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFSElEQVR4nO3YMavWZRzH4fuoU4lgQnWmg6RH3EWiV5C1ODQ5CgotLeHgZI7nCNooKEJE4AsQx8aMbBGcjsIZBUfBNBT+vQAhHfR+hs91rc8fvr/p4cO9tizLMgCArD2rPgAAWC0xAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABA3L53/XDj1taHvOMNHx18MXXv0P5/pu6NMcbTP9en7m1c+mPq3s71k1P3Pv70+dS9McZ4+OVvU/f2fP5o6t77cOyna1P3Xq6/nrp399TPU/fGGOOb33+Yurd59u+pezu3TkzdW4Xdr29O3Xvbf4eXAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBubVmWZdVHAACr42UAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQ
EAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgbt+7frhxc/tD3vGGvc/2Tt17fOb61L0xxjh859zUvc3z96fufXbvwNS93SvHp+6NMcaFrV+n7p3+4sHUvffhyPbVqXuvDr2eurf77Y2pe2OMsfnL91P3Dl+8N3XvyY9fTd3795Nl6t4YY1z+7vbUvTNH//rf370MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFry7Isqz4CAFgdLwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAEPcfiCw/6Avfi0AAAAAASUVORK5CYII=", - "text/plain": [ - "<Figure size 640x480 with 4 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ + "import math\n", "from torch import nn\n", "\n", "# define embedding class\n", "class Embedding(nn.Module):\n", - " def __init__(self, patch_size, in_channels, out_channels, return_patches=False):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", " super(Embedding, self).__init__()\n", " self.patch_size = patch_size\n", " self.in_channels = in_channels\n", " self.out_channels = out_channels\n", " self.return_patches = return_patches\n", + " self.classify = extra_token\n", " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", " self.norm = nn.LayerNorm(out_channels)\n", "\n", @@ -186,19 +69,39 @@ " patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)\n", "\n", " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", 
+ "\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", + "\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + "\n", + " return pos_encoding\n", " \n", " def forward(self, x):\n", " # get the patches\n", - " print('patch size: ', self.patch_size)\n", " patches = self.get_patches(x, patch_size=self.patch_size)\n", " # flatten the patches\n", " patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)\n", + " # add extra embedding token if classification is needed\n", + " if self.classify:\n", + " extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)\n", + " patches = torch.cat((extra_token, patches), dim=0)\n", + " \n", " # get the embedding\n", " embedding = self.conv(patches)\n", " # flatten the embedding\n", " embedding = embedding.reshape(-1, self.out_channels)\n", " # normalize the embedding\n", " embedding = self.norm(embedding)\n", + " # add the positional encoding\n", + " pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])\n", + " embedding = embedding + pos_encoding\n", " \n", " if self.return_patches:\n", " return embedding, patches\n", @@ -208,7 +111,7 @@ "\n", "# use the embedding class\n", "patch_size = 16\n", - "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True)\n", + "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True, extra_token=True)\n", "\n", "embedding(img.unsqueeze(0))\n", "\n", @@ -227,157 +130,474 @@ " ax[i, j].axis('off')\n", "\n", "# plot the embeddings\n", + "# plot the embeddings\n", "no = int(32 / patch_size)\n", "fig, ax = plt.subplots(no, no)\n", "for i in range(no):\n", " for j in range(no):\n", - " ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))\n", + " ax[i, j].imshow(embedding.squeeze()[i * no + j, :].detach().numpy().reshape(1, 
-1))\n", " ax[i, j].axis('off')\n", "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class SelfAttention(nn.Module):\n", + " def __init__(self, embed_dim):\n", + " super().__init__()\n", + "\n", + " # Query, Key, Value weight matrices\n", + " self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)\n", + "\n", + " # Final output weight matrix\n", + " self.output_linear = nn.Linear(embed_dim, embed_dim)\n", + " \n", + " def forward(self, x):\n", + " batch_size, seq_len, embed_dim = x.size()\n", + "\n", + " # Create queries, keys, and values\n", + " qkv = self.qkv_linear(x)\n", + " q, k, v = torch.split(qkv, embed_dim, dim=-1)\n", "\n", - " " + " # Compute attention scores\n", + " scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5)\n", + " attn = torch.softmax(scores, dim=-1)\n", + "\n", + " # Apply attention to values\n", + " weighted_values = torch.matmul(attn, v)\n", + "\n", + " # Apply final output weight matrix\n", + " output = self.output_linear(weighted_values)\n", + "\n", + " return output\n", + "\n", + "# get the context\n", + "context = SelfAttention(embed_dim=8)\n", + "\n", + "# get the embedding\n", + "embedding = Embedding(patch_size=16, in_channels=3, out_channels=8, return_patches=False, extra_token=True)\n", + "\n", + "# get the embedding\n", + "embedding = embedding(img.unsqueeze(0))\n", + "\n", + "# get the context\n", + "context = context(embedding)\n", + "\n", + "print('embedding: ', embedding.shape)\n", + "print('context: ', context.shape)\n", + "\n", + "\n", + "\n", + "\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", - " 0.0000e+00, 1.0000e+00],\n", - " [ 8.4147e-01, 5.4030e-01, 6.8156e-01, ..., 
1.0000e+00,\n", - " 1.3335e-04, 1.0000e+00],\n", - " [ 9.0930e-01, -4.1615e-01, 9.9748e-01, ..., 1.0000e+00,\n", - " 2.6670e-04, 1.0000e+00],\n", - " ...,\n", - " [ 3.7961e-01, -9.2515e-01, -4.6453e-01, ..., 9.9985e-01,\n", - " 1.2935e-02, 9.9992e-01],\n", - " [-5.7338e-01, -8.1929e-01, -9.4349e-01, ..., 9.9985e-01,\n", - " 1.3068e-02, 9.9991e-01],\n", - " [-9.9921e-01, 3.9821e-02, -9.1628e-01, ..., 9.9985e-01,\n", - " 1.3201e-02, 9.9991e-01]])\n" - ] - }, - { - "data": { - "text/plain": [ - "(-0.5, 63.5, 99.5, -0.5)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import math\n", + "from torch import nn\n", "import torch\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# get cifar10 data\n", + "# import cifar10 dataset\n", + "from torchvision.datasets import CIFAR10\n", + "\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", + "\n", + "# set a seed\n", + "torch.manual_seed(0)\n", + "\n", + "# define embedding class\n", + "\n", + "\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.return_patches = return_patches\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n", + " self.classify = extra_token\n", + " self.conv = nn.Conv2d(in_channels, out_channels,\n", + " kernel_size=patch_size, stride=patch_size)\n", + " self.patch_conv = nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", + " self.norm = nn.LayerNorm(out_channels)\n", + " self.proj = nn.Linear(out_channels, 
out_channels)\n", + "\n", + " def get_patches(self, x, patch_size=8):\n", + " # get the patches\n", + " patches = x.unfold(2, patch_size, patch_size).unfold(\n", + " 3, patch_size, patch_size).to(x.device)\n", + "\n", + " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", + "\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", + "\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + "\n", + " return pos_encoding\n", + "\n", + " def forward(self, x):\n", + "\n", + " # embedding = self.patch_conv(x)\n", + " embedding = self.get_patches(x, self.patch_size)\n", + "\n", + " # flatten the embedding\n", + " embedding = embedding.reshape(x.shape[0], -1, self.out_channels)\n", + "\n", + " if self.classify:\n", + " class_embedding = self.class_embedding.expand(x.shape[0], -1, -1)\n", + " embedding = torch.cat([class_embedding, embedding], dim=1)\n", + "\n", + " \n", + "\n", + " # normalize the embedding\n", + " embedding = self.norm(embedding)\n", + "\n", + " # project the embedding\n", + " embedding = self.proj(embedding)\n", + "\n", + " # add the positional encoding account for batch size\n", + " pos_encoding = self.get_pos_encoding(\n", + " self.out_channels, embedding.shape[1]).to(x.device)\n", "\n", + " embedding = embedding + pos_encoding\n", "\n", "\n", - "def sinusoidal_encoding_table(n_position, d_hid, padding_idx=None):\n", - " '''Generate sinusoidal position encoding table'''\n", - " encoding_table = torch.zeros(n_position, d_hid)\n", - " position = torch.arange(0, n_position).unsqueeze(1)\n", - " div_term = torch.exp(torch.arange(0, d_hid, 2) * -(math.log(10000.0) / d_hid))\n", - " encoding_table[:, 0::2] = torch.sin(position * div_term)\n", - " encoding_table[:, 1::2] = torch.cos(position * div_term)\n", - " if 
padding_idx is not None:\n", - " encoding_table[padding_idx] = 0.\n", - " return encoding_table\n", + " if self.return_patches:\n", + " patches = self.get_patches(x, self.patch_size)\n", + " patches = patches.reshape(\n", + " x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n", + "\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n", + "\n", + "\n", + "# use the embedding class\n", + "patch_size = 16\n", + "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True, extra_token=True)\n", + "\n", + "\n", + "# get the training data\n", + "train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())\n", + "# get data loader\n", + "train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True)\n", + "\n", + "# take a batch of images\n", + "x, y = next(iter(train_loader))\n", + "\n", + "# run the embedding on the images\n", + "embedding, patches = embedding(x)\n", + "\n", + "print('embedding shape: ', embedding.shape)\n", + "print('patches shape: ', patches.shape)\n", + "# plot the patches\n", + "no = int(32 / patch_size)\n", + "fig, ax = plt.subplots(no, no)\n", + "for i in range(no):\n", + " for j in range(no):\n", + " ax[i, j].imshow(patches[0, i * no + j, :].permute(1, 2, 0))\n", + " ax[i, j].axis('off')\n", + "\n", + "# plot the embeddings\n", + "# plot the embeddings\n", + "no = int(32 / patch_size)\n", + "fig, ax = plt.subplots(no, no)\n", + "for i in range(no):\n", + " for j in range(no):\n", + " ax[i, j].imshow(embedding.squeeze()[i * no + j, :].detach().numpy().reshape(1, -1))\n", + " ax[i, j].axis('off')\n", + "\n", "\n", - "seq_len = 100\n", - "embedding_dim = 64\n", "\n", - "pos_encoding = sinusoidal_encoding_table(seq_len, embedding_dim)\n", "\n", - "print(pos_encoding)\n", "\n", - "# plot the position encoding\n", - "fig, ax = plt.subplots(1, 1)\n", - "ax.imshow(pos_encoding.detach().numpy())\n", - "ax.axis('off')\n", "\n" 
] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 1.0000, 1.0000, 1.0000],\n", - " [ 0.8415, 0.6816, 0.5332, ..., 1.0000, 1.0000, 1.0000],\n", - " [ 0.9093, 0.9975, 0.9021, ..., 1.0000, 1.0000, 1.0000],\n", - " ...,\n", - " [ 0.3796, -0.4645, -0.9086, ..., 0.9997, 0.9999, 0.9999],\n", - " [-0.5734, -0.9435, -0.9914, ..., 0.9997, 0.9998, 0.9999],\n", - " [-0.9992, -0.9163, -0.7687, ..., 0.9997, 0.9998, 0.9999]]])\n" + "Files already downloaded and verified\n", + "Files already downloaded and verified\n", + "epoch: 0 batch: 0 loss: 2.338132619857788\n", + "epoch: 0 batch: 100 loss: 2.208583354949951\n", + "epoch: 0 batch: 200 loss: 1.97751784324646\n", + "epoch: 0 batch: 300 loss: 2.0264248847961426\n", + "epoch: 0 batch: 400 loss: 2.054466724395752\n", + "epoch: 0 batch: 500 loss: 1.817784309387207\n", + "epoch: 0 batch: 600 loss: 1.836090087890625\n", + "epoch: 0 batch: 700 loss: 1.5311310291290283\n" ] }, { - "data": { - "text/plain": [ - "(-0.5, 63.5, 99.5, -0.5)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 217\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39maccuracy: \u001b[39m\u001b[39m'\u001b[39m, accuracy \u001b[39m/\u001b[39m \u001b[39mlen\u001b[39m(test_loader))\n\u001b[1;32m 216\u001b[0m \u001b[39m# train the model\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m 
train(model\u001b[39m.\u001b[39;49mcuda(), train_loader, criterion, optimizer, epochs\u001b[39m=\u001b[39;49m\u001b[39m10\u001b[39;49m)\n\u001b[1;32m 218\u001b[0m \u001b[39m# test the model\u001b[39;00m\n\u001b[1;32m 219\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m'\u001b[39m)\n", + "Cell \u001b[0;32mIn[1], line 176\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, train_loader, criterion, optimizer, epochs)\u001b[0m\n\u001b[1;32m 174\u001b[0m loss\u001b[39m.\u001b[39mbackward()\n\u001b[1;32m 175\u001b[0m \u001b[39m# update the weights\u001b[39;00m\n\u001b[0;32m--> 176\u001b[0m optimizer\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 177\u001b[0m \u001b[39m# print the loss\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[39mif\u001b[39;00m i \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py:140\u001b[0m, in \u001b[0;36mOptimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m profile_name \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mOptimizer.step#\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.step\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(obj\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m)\n\u001b[1;32m 139\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mautograd\u001b[39m.\u001b[39mprofiler\u001b[39m.\u001b[39mrecord_function(profile_name):\n\u001b[0;32m--> 140\u001b[0m out \u001b[39m=\u001b[39m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 141\u001b[0m obj\u001b[39m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 142\u001b[0m \u001b[39mreturn\u001b[39;00m out\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py:23\u001b[0m, in 
\u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 22\u001b[0m torch\u001b[39m.\u001b[39mset_grad_enabled(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdefaults[\u001b[39m'\u001b[39m\u001b[39mdifferentiable\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m---> 23\u001b[0m ret \u001b[39m=\u001b[39m func(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 24\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m 25\u001b[0m torch\u001b[39m.\u001b[39mset_grad_enabled(prev_grad)\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/adam.py:226\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure, grad_scaler)\u001b[0m\n\u001b[1;32m 223\u001b[0m state[\u001b[39m'\u001b[39m\u001b[39mmax_exp_avg_sq\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros_like(p, memory_format\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mpreserve_format)\n\u001b[1;32m 225\u001b[0m exp_avgs\u001b[39m.\u001b[39mappend(state[\u001b[39m'\u001b[39m\u001b[39mexp_avg\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m--> 226\u001b[0m exp_avg_sqs\u001b[39m.\u001b[39;49mappend(state[\u001b[39m'\u001b[39m\u001b[39mexp_avg_sq\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m 228\u001b[0m \u001b[39mif\u001b[39;00m group[\u001b[39m'\u001b[39m\u001b[39mamsgrad\u001b[39m\u001b[39m'\u001b[39m]:\n\u001b[1;32m 229\u001b[0m max_exp_avg_sqs\u001b[39m.\u001b[39mappend(state[\u001b[39m'\u001b[39m\u001b[39mmax_exp_avg_sq\u001b[39m\u001b[39m'\u001b[39m])\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ - "def get_pos_encoding(max_len, d_emb):\n", - " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", - " i = torch.arange(0, d_emb, 2).float()\n", + "# create a transformer class model\n", + "from torch import nn\n", + "import torch\n", + "from 
torch.nn import functional as F\n", + "import math\n", "\n", - " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "from torchvision.datasets import CIFAR10\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", "\n", - " sin = torch.sin(pos * div)\n", - " cos = torch.cos(pos * div)\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.return_patches = return_patches\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n", + " self.classify = extra_token\n", + " self.patch_conv = nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", + " self.norm = nn.LayerNorm(out_channels)\n", + " self.proj = nn.Linear(out_channels, out_channels)\n", + "\n", + " def get_patches(self, x, patch_size=8):\n", + " # get the patches\n", + " patches = x.unfold(2, patch_size, patch_size).unfold(\n", + " 3, patch_size, patch_size).to(x.device)\n", + "\n", + " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", "\n", - " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", "\n", - " return pos_encoding\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", "\n", - "seq_len = 100\n", - "embedding_dim = 64\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", "\n", - "pos_encoding = get_pos_encoding(seq_len, embedding_dim)\n", + " return pos_encoding\n", + "\n", + " def forward(self, x):\n", + "\n", + " embedding = self.patch_conv(x)\n", + "\n", + " # flatten the embedding\n", + " embedding = 
embedding.reshape(x.shape[0], -1, self.out_channels)\n", "\n", - "print(pos_encoding)\n", + " if self.classify:\n", + " class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1)\n", + " embedding = torch.cat([class_embedding, embedding], dim=1)\n", "\n", - "# plot the position encoding\n", - "fig, ax = plt.subplots(1, 1)\n", - "ax.imshow(pos_encoding.squeeze().detach().numpy())\n", - "ax.axis('off')\n" + " # normalize the embedding\n", + " embedding = self.norm(embedding)\n", + "\n", + " # project the embedding\n", + " embedding = self.proj(embedding)\n", + "\n", + " # add the positional encoding account for batch size\n", + " pos_encoding = self.get_pos_encoding(\n", + " self.out_channels, embedding.shape[1]).to(x.device)\n", + "\n", + " embedding = embedding + pos_encoding\n", + "\n", + " # apply the dropout\n", + " embedding = F.dropout(embedding, p=0.1, training=self.training)\n", + "\n", + "\n", + " if self.return_patches:\n", + " patches = self.get_patches(x, self.patch_size)\n", + " patches = patches.reshape(\n", + " x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n", + "\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n", + " \n", + "\n", + "\n", + "class MyTransformerLayer(nn.Module):\n", + " def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):\n", + " super(MyTransformerLayer, self).__init__()\n", + " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n", + " self.linear1 = nn.Linear(d_model, d_model)\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.linear2 = nn.Linear(d_model, d_model)\n", + " self.norm1 = nn.LayerNorm(d_model)\n", + " self.norm2 = nn.LayerNorm(d_model)\n", + " self.dropout1 = nn.Dropout(dropout)\n", + " self.dropout2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", + " src2 = self.self_attn(src, src, src, attn_mask=src_mask,\n", + " 
key_padding_mask=src_key_padding_mask)[0]\n", + " src = src + self.dropout1(src2)\n", + " src = self.norm1(src)\n", + " src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))\n", + " src = src + self.dropout2(src2)\n", + " src = self.norm2(src)\n", + " return src\n", + " \n", + "# define the transformer class\n", + "class MyTransformer(nn.Module):\n", + " def __init__(self) -> None:\n", + " super(MyTransformer, self).__init__()\n", + "\n", + " # define the embedding\n", + " self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=64, return_patches=False, extra_token=True)\n", + "\n", + " # self.encoder = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)\n", + " self.encoder = MyTransformerLayer(d_model=64, nhead=8, dropout=0.1, batch_first=True)\n", + "\n", + " # create the linear layer for all the outputs from the transformer\n", + " self.fc = nn.Linear(64, 10)\n", + "\n", + "\n", + " def forward(self, x):\n", + " # get the embedding\n", + " embedding = self.embedding(x)\n", + " # apply the transformer\n", + " out = self.encoder(embedding)\n", + " # out = self.fc_out_trans(out)\n", + "\n", + " # get the classification token\n", + " classification_token = out[:, 0, :]\n", + "\n", + " # out = out.mean(dim=1)\n", + " # pass the output through the linear layer\n", + " # out = self.fc(out)\n", + " out = self.fc(classification_token)\n", + "\n", + "\n", + " return out \n", + "\n", + "\n", + "# get the data \n", + "train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())\n", + "test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())\n", + "\n", + "# get data loader\n", + "train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)\n", + "\n", + "# create the model\n", + "model = MyTransformer()\n", + "\n", + "# define the loss function\n", + 
"criterion = nn.CrossEntropyLoss()\n", + "\n", + "# define the optimizer\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "# define the training loop for gpu\n", + "def train(model, train_loader, criterion, optimizer, epochs=10):\n", + " # set the model to train mode\n", + " model.train()\n", + " # loop through the epochs\n", + " for epoch in range(epochs):\n", + " # loop through the batches\n", + " for i, (x, y) in enumerate(train_loader):\n", + " # move the data to the gpu\n", + " x = x.cuda()\n", + " y = y.cuda()\n", + " # zero the gradients\n", + " optimizer.zero_grad()\n", + " # get the output\n", + " out = model(x)\n", + " # calculate the loss\n", + " loss = criterion(out, y)\n", + " # calculate the gradients\n", + " loss.backward()\n", + " # update the weights\n", + " optimizer.step()\n", + " # print the loss\n", + " if i % 100 == 0:\n", + " print('epoch: ', epoch, 'batch: ', i, 'loss: ', loss.item())\n", + "\n", + " # get parameters\n", + " for name, param in model.named_parameters():\n", + " if param.requires_grad:\n", + " # pass\n", + " \n", + " # get the names of the layers which do not change\n", + " if param.grad.abs().mean() == 0:\n", + " print(\"layer does not change\")\n", + " print(name)\n", + " \n", + "\n", + "# define the testing loop for gpu\n", + "def test(model, test_loader, criterion):\n", + " # set the model to eval mode\n", + " model.eval()\n", + " # initialize the loss\n", + " loss = 0\n", + " # initialize the accuracy\n", + " accuracy = 0\n", + " # loop through the batches\n", + " for i, (x, y) in enumerate(test_loader):\n", + " # move the data to the gpu\n", + " x = x.cuda()\n", + " y = y.cuda()\n", + " # get the output\n", + " out = model(x)\n", + " # calculate the loss\n", + " loss += criterion(out, y).item()\n", + " # calculate the accuracy\n", + " accuracy += (out.argmax(1) == y).float().mean().item()\n", + " # print the loss\n", + " print('loss: ', loss / len(test_loader))\n", + " # print the 
accuracy\n", + " print('accuracy: ', accuracy / len(test_loader))\n", + "\n", + "# train the model\n", + "train(model.cuda(), train_loader, criterion, optimizer, epochs=10)\n", + "# test the model\n", + "print('testing')\n", + "test(model.cuda(), test_loader, criterion)\n", + "\n" ] } ], diff --git a/pth_lighting_example/mnist.ipynb b/pth_lighting_example/mnist.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3b55ac56167001fbbbd7f31eb86a2f0c31d101fd --- /dev/null +++ b/pth_lighting_example/mnist.ipynb @@ -0,0 +1,496 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import pandas as pd\n", + "import seaborn as sn\n", + "import torch\n", + "from IPython.core.display import display\n", + "from pytorch_lightning import LightningModule, Trainer\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchmetrics import Accuracy\n", + "from torchvision import transforms\n", + "from torchvision.datasets import MNIST\n", + "\n", + "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", + "BATCH_SIZE = 256 if torch.cuda.is_available() else 64\n", + "\n", + "class LitMNIST(LightningModule):\n", + " def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", + " super().__init__()\n", + "\n", + " # Set our init args as class attributes\n", + " self.data_dir = data_dir\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Hardcode some dataset specific attributes\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " 
transforms.Normalize((0.1307,), (0.3081,)),\n", + " ]\n", + " )\n", + "\n", + " # Define PyTorch model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes),\n", + " )\n", + "\n", + " self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", + " self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " self.val_accuracy.update(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log(\"val_loss\", loss, prog_bar=True)\n", + " self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " self.test_accuracy.update(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log(\"test_loss\", loss, prog_bar=True)\n", + " self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # 
download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == \"fit\" or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == \"test\" or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)\n", + "\n", + "\n", + "model = LitMNIST()\n", + "trainer = Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1 if torch.cuda.is_available() else None, # limiting for iPython runs\n", + " max_epochs=3,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + " logger=CSVLogger(save_dir=\"logs/\"),\n", + ")\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n", + "del metrics[\"step\"]\n", + "metrics.set_index(\"epoch\", inplace=True)\n", + "display(metrics.dropna(axis=1, how=\"all\").head())\n", + "sn.relplot(data=metrics, kind=\"line\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import pandas as pd\n", + "import seaborn as sn\n", + 
"import torch\n", + "from IPython.core.display import display\n", + "from pytorch_lightning import LightningModule, Trainer\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchmetrics import Accuracy\n", + "from torchvision import transforms\n", + "from torchvision.datasets import MNIST, CIFAR10\n", + "\n", + "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", + "BATCH_SIZE = 256 if torch.cuda.is_available() else 64\n", + "\n", + "class MNISTModel(LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.l1 = torch.nn.Linear(32 * 32, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", + "\n", + " def accuracy(self, y_hat, y):\n", + " preds = torch.argmax(y_hat, dim=1)\n", + " return (preds == y).float().mean()\n", + "\n", + " def training_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "\n", + " return loss\n", + " \n", + " def validation_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"val_loss\", loss, prog_bar=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " self.log(\"val_acc\", self.accuracy(y_hat, y), prog_bar=True)\n", + " return loss\n", + " \n", + " def test_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"test_loss\", loss, prog_bar=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " 
self.log(\"test_acc\", self.accuracy(y_hat, y), prog_bar=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.02)\n", + " \n", + "\n", + "# Init our model\n", + "mnist_model = MNISTModel()\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "# train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", + "# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", + "\n", + "# test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())\n", + "# test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)\n", + "\n", + "# Init DataLoader from CIFAR-10 Dataset\n", + "train_ds = CIFAR10(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", + "\n", + "test_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)\n", + "\n", + "\n", + "\n", + "# Initialize a trainer\n", + "trainer = Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1 if torch.cuda.is_available() else None, # limiting for iPython runs\n", + " max_epochs=5,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + ")\n", + "\n", + "# Train the model ⚡\n", + "trainer.fit(mnist_model, train_loader)\n", + "\n", + "# Test the model\n", + "trainer.test(mnist_model, dataloaders=test_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'cifar_example'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[13], line 46\u001b[0m\n\u001b[1;32m 42\u001b[0m 
\u001b[39mreturn\u001b[39;00m DataLoader(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcifar10_test, batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatch_size)\n\u001b[1;32m 45\u001b[0m \u001b[39m# import modules\u001b[39;00m\n\u001b[0;32m---> 46\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mcifar_example\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcifar_transformer_modules\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmy_transformer_layer\u001b[39;00m \u001b[39mimport\u001b[39;00m MyTransformerLayer\n\u001b[1;32m 47\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mcifar_example\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcifar_transformer_modules\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39membedding\u001b[39;00m \u001b[39mimport\u001b[39;00m Embedding\n\u001b[1;32m 49\u001b[0m \u001b[39m# create resnet model for cifar10 classification\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cifar_example'" + ] + } + ], + "source": [ + "import pytorch_lightning as pl\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import CIFAR10\n", + "from torchvision import transforms\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "\n", + "from torchvision.models import resnet18\n", + "\n", + "\n", + "# create datamodule for cifar10\n", + "class CIFAR10DataModule(pl.LightningDataModule):\n", + " def __init__(self, data_dir: str = \"./\", batch_size: int = 32):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + " self.transform = transforms.Compose(\n", + " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", + " )\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " CIFAR10(self.data_dir, train=True, download=True)\n", + " CIFAR10(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + " # Assign train/val datasets 
for use in dataloaders\n", + " if stage == \"fit\" or stage is None:\n", + " cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", + " self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == \"test\" or stage is None:\n", + " self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=16)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=16)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.cifar10_test, batch_size=self.batch_size)\n", + " \n", + "\n", + "# # import modules\n", + "# from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer\n", + "# from cifar_example.cifar_transformer_modules.embedding import Embedding\n", + "\n", + "# create resnet model for cifar10 classification\n", + "class CIFAR10Model(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.model = resnet18(pretrained=False, num_classes=10)\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log 
accuracy\n", + " self.log(\"val_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"test_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " self.log(\"test_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.001)\n", + "\n", + " def accuracy(self, y_hat, y):\n", + " preds = torch.argmax(y_hat, dim=1)\n", + " return (preds == y).float().mean()\n", + " \n", + "\n", + "# # create resnet model for cifar10 classification\n", + "# class CIFAR10Model(pl.LightningModule):\n", + "# def __init__(self):\n", + "# super().__init__()\n", + "# self.embedding_size=64\n", + "# self.criterion = torch.nn.CrossEntropyLoss()\n", + "# self.embedding = Embedding(\n", + "# patch_size=8, \n", + "# in_channels=3, \n", + "# out_channels=self.embedding_size, \n", + "# return_patches=False, \n", + "# extra_token=True\n", + "# )\n", + "# self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3)\n", + "# self.fc = nn.Linear(self.embedding_size, 10)\n", + " \n", + "# def forward(self, x):\n", + "# embedding = self.embedding(x)\n", + "# context = self.self_attention(embedding)\n", + "# # get the first token\n", + "# context = context[:, 0, :]\n", + "\n", + "# # context = context.mean(dim=1)\n", + "\n", + "# # get the classification\n", + "# context = self.fc(context)\n", + " \n", + "# return context\n", + "\n", + "# def training_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, 
logger=True)\n", + "# # log accuracy\n", + "# self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def validation_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# # log accuracy\n", + "# self.log(\"val_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def test_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"test_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# # log accuracy\n", + "# self.log(\"test_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def configure_optimizers(self):\n", + "# return torch.optim.Adam(self.parameters(), lr=0.001)\n", + "\n", + "# def accuracy(self, y_hat, y):\n", + "# preds = torch.argmax(y_hat, dim=1)\n", + "# return (preds == y).float().mean()\n", + " \n", + "\n", + "# create trainer\n", + "trainer = pl.Trainer(\n", + " gpus=1 if torch.cuda.is_available() else None,\n", + " max_epochs=5,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + " logger=CSVLogger(save_dir=\"logs/\")\n", + ")\n", + "\n", + "# create datamodule\n", + "dm = CIFAR10DataModule()\n", + "\n", + "# create model\n", + "model = CIFAR10Model()\n", + "\n", + "# train model\n", + "trainer.fit(model, dm)\n", + "\n", + "# test model\n", + "trainer.test(model, datamodule=dm)\n", + "\n", + "# save model\n", + "trainer.save_checkpoint(\"cifar10_model.ckpt\")\n", + "\n", + "# print the metrics\n", + "print(trainer.logged_metrics)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": 
"python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}