diff --git a/cifar_example/__init__.py b/cifar_example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cifar_example/cifar10_lightning_ver_2.py b/cifar_example/cifar10_lightning_ver_2.py new file mode 100644 index 0000000000000000000000000000000000000000..c12ca99444bdb93d2d58a82c7f95ed600fcf2533 --- /dev/null +++ b/cifar_example/cifar10_lightning_ver_2.py @@ -0,0 +1,200 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import pytorch_lightning as pl +from torch.utils.data import DataLoader, random_split +from torchvision.datasets import CIFAR10 +from torchvision import transforms +from pytorch_lightning.callbacks.progress import TQDMProgressBar +from pytorch_lightning.loggers import CSVLogger, WandbLogger + + +from torchvision.models import resnet18 + + +# create datamodule for cifar10 +class CIFAR10DataModule(pl.LightningDataModule): + def __init__(self, data_dir: str = "./", batch_size: int = 32): + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + self.transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + def prepare_data(self): + # download + CIFAR10(self.data_dir, train=True, download=True) + CIFAR10(self.data_dir, train=False, download=True) + + def setup(self, stage=None): + # Assign train/val datasets for use in dataloaders + if stage == "fit" or stage is None: + cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform) + self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000]) + + # Assign test dataset for use in dataloader(s) + if stage == "test" or stage is None: + self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform) + + def train_dataloader(self): + return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=16) + + def val_dataloader(self): + return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=16) + + def test_dataloader(self): + return DataLoader(self.cifar10_test, batch_size=self.batch_size) + + +# import modules +from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer +from cifar_example.cifar_transformer_modules.embedding import Embedding + +# create resnet model for cifar10 classification +class CIFAR10Model(pl.LightningModule): + def __init__(self): + super().__init__() + self.model = self.create_model() + + def forward(self, x): + return self.model(x) + + def create_model(self): + model = resnet18(pretrained=False, num_classes=10) + model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + model.maxpool = nn.Identity() + return model + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def accuracy(self, y_hat, y): + preds = torch.argmax(y_hat, dim=1) + return (preds == y).float().mean() + + +# create resnet model for cifar10 classification +# class CIFAR10Model(pl.LightningModule): + def __init__(self): + super().__init__() + self.embedding_size=64 + self.criterion = torch.nn.CrossEntropyLoss() + self.embedding = Embedding( + patch_size=16, + in_channels=3, + out_channels=self.embedding_size, + return_patches=False, + extra_token=True + ) + self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3) + self.fc = nn.Linear(self.embedding_size, 10) + + def forward(self, x): + embedding = self.embedding(x) + context = self.self_attention(embedding) + # get the first token + context = context[:, 0, :] + + # context = context.mean(dim=1) + + # get the classification + context = self.fc(context) + + return context + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("train_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("val_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # log accuracy + self.log("test_acc", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def accuracy(self, y_hat, y): + preds = torch.argmax(y_hat, dim=1) + return (preds == y).float().mean() + + +# create trainer +trainer = pl.Trainer( + gpus=1 if torch.cuda.is_available() else None, + max_epochs=10, + callbacks=[TQDMProgressBar(refresh_rate=20)], + logger=WandbLogger( + entity="maciej-wielgosz-nibio", + project="cifar10_example_transformer", + log_model=True, + save_code=True, + save_dir="wandb/" + ) +) + +# create datamodule +dm = CIFAR10DataModule(batch_size=64) + +# create model +model = CIFAR10Model() + +# train model +trainer.fit(model, dm) + +# test model +trainer.test(model, datamodule=dm) + +# save model +trainer.save_checkpoint("cifar10_model.ckpt") + +# print the metrics +print(trainer.logged_metrics) \ No newline at end of file diff --git a/cifar_example/cifar_example_transformer.py b/cifar_example/cifar_example_transformer.py index 6e77b5d024ae550c4a0bcec1fe310e6a6cf2231a..4df2190404af9f51b643c306ca99ca844000fcfd 100644 --- a/cifar_example/cifar_example_transformer.py +++ b/cifar_example/cifar_example_transformer.py @@ -228,7 +228,7 @@ class MyTransformerLayer(nn.Module): class PthBasedTransformer(nn.Module): def __init__(self, embedding_size=64) -> None: super().__init__() - self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=True, extra_token=True) + self.embedding = Embedding(patch_size=8, in_channels=3, out_channels=embedding_size, return_patches=False, extra_token=True) # self.self_attention = TransformerEncoderLayer( # d_model=embedding_size, # nhead=16, @@ -241,7 +241,7 @@ class PthBasedTransformer(nn.Module): self.fc = nn.Linear(embedding_size, 10) def forward(self, x): - embedding, patches = self.embedding(x) + embedding = self.embedding(x) context = self.self_attention(embedding) # get the first token context = context[:, 0, :] diff --git a/cifar_example/cifar_transformer_modules/__init__.py b/cifar_example/cifar_transformer_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cifar_example/cifar_transformer_modules/embedding.py b/cifar_example/cifar_transformer_modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e6913370336709c84f41f131349a04f0367471e6 --- /dev/null +++ b/cifar_example/cifar_transformer_modules/embedding.py @@ -0,0 +1,71 @@ + +import torch +import torch.nn as nn +import math + + +class Embedding(nn.Module): + def __init__(self, patch_size, in_channels, out_channels, batch_size=1, return_patches=False, extra_token=False): + super(Embedding, self).__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.return_patches = return_patches + self.class_embedding = nn.Parameter(torch.randn(1, out_channels)) + self.classify = extra_token + self.patch_conv = nn.Conv2d( + in_channels, out_channels, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(out_channels) + self.proj = nn.Linear(out_channels, out_channels) + + def get_patches(self, x, patch_size=8): + # get the patches + patches = x.unfold(2, patch_size, patch_size).unfold( + 3, patch_size, patch_size).to(x.device) + + return patches + + def get_pos_encoding(self, d_emb, max_len): + pos = torch.arange(0, max_len).float().unsqueeze(1) + i = torch.arange(0, d_emb, 2).float() + + div = torch.exp(-i * math.log(10000) / d_emb) + + sin = torch.sin(pos * div) + cos = torch.cos(pos * div) + + pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb) + + return pos_encoding + + def forward(self, x): + + embedding = self.patch_conv(x) + + # flatten the embedding + embedding = embedding.reshape(x.shape[0], -1, self.out_channels) + + if self.classify: + class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1) + embedding = torch.cat([class_embedding, embedding], dim=1) + + # normalize the embedding + embedding = self.norm(embedding) + + # project the embedding + embedding = self.proj(embedding) + + # add the positional encoding account for batch size + pos_encoding = self.get_pos_encoding( + self.out_channels, embedding.shape[1]).to(x.device) + + embedding = embedding + pos_encoding + + if self.return_patches: + patches = self.get_patches(x, self.patch_size) + patches = patches.reshape( + x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size) + + return embedding, patches + else: + return embedding diff --git a/cifar_example/cifar_transformer_modules/multi_head_attention.py b/cifar_example/cifar_transformer_modules/multi_head_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..004098f7f74c0fe4e4583d0108e0dce9b2c125b9 --- /dev/null +++ b/cifar_example/cifar_transformer_modules/multi_head_attention.py @@ -0,0 +1,37 @@ + +import torch +import torch.nn as nn + +class SelfAttentionParam(nn.Module): + def __init__(self, in_features, out_features) -> None: + super(SelfAttentionParam, self).__init__() + self.query = nn.Linear(in_features, out_features) + self.key = nn.Linear(in_features, out_features) + self.value = nn.Linear(in_features, out_features) + + def forward(self, x): + batch_size, num_embeddings, embedding_dim = x.size() + Q = self.query(x).view(batch_size, num_embeddings, -1) + K = self.key(x).view(batch_size, num_embeddings, -1) + V = self.value(x).view(batch_size, num_embeddings, -1) + # Q, K, V = [batch_size, num_embeddings, embedding_dim] + energy = torch.bmm(Q, K.permute(0, 2, 1)) + # energy = [batch_size, num_embeddings, num_embeddings] + attention = torch.softmax(energy, dim=-1) + out = torch.bmm(attention, V) + out = out.view(batch_size, num_embeddings, -1) + return out + +class MultiHeadAttention(nn.Module): + def __init__(self, embedd_size, heads=8) -> None: + super(MultiHeadAttention, self).__init__() + self.heads = heads + self.attention = nn.ModuleList([SelfAttentionParam(embedd_size, embedd_size) for _ in range(heads)]) + self.projection = nn.Linear(heads * embedd_size, embedd_size) + + def forward(self, x): + out = [self.attention[i](x) for i in range(self.heads)] + out = torch.cat(out, dim=2) + + out = self.projection(out) + return out \ No newline at end of file diff --git a/cifar_example/cifar_transformer_modules/my_transformer_layer.py b/cifar_example/cifar_transformer_modules/my_transformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..394d2d7b750bddd066b834c704ab46147ae315cb --- /dev/null +++ b/cifar_example/cifar_transformer_modules/my_transformer_layer.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cifar_example.cifar_transformer_modules.multi_head_attention import MultiHeadAttention + +class MyTransformerLayer(nn.Module): + def __init__(self, d_model, nhead, dropout=0.1): + super(MyTransformerLayer, self).__init__() + self.self_attn = MultiHeadAttention(embedd_size=d_model, heads=nhead) + self.linear1 = nn.Linear(d_model, d_model) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_model, d_model) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, src): + src2 = self.self_attn(src) + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src \ No newline at end of file diff --git a/cifar_example/self_attention.ipynb b/cifar_example/self_attention.ipynb index a223d112a9797344c7c5b2934a9e370c50f0db99..0e3ed7bf954d9a0c09c3f2dc1ca6c7994862d611 100644 --- a/cifar_example/self_attention.ipynb +++ b/cifar_example/self_attention.ipynb @@ -404,14 +404,14 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 95, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([32, 5, 16])\n" + "torch.Size([32, 10, 16])\n" ] } ], diff --git a/cifar_example/vis.ipynb b/cifar_example/vis.ipynb index a3cd4854564118826b197d285a3f8fdab47a2e7e..92350a9a2d81ef8187e7c59eb954d11e5f80392a 100644 --- a/cifar_example/vis.ipynb +++ b/cifar_example/vis.ipynb @@ -2,27 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Files already downloaded and verified\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAw70lEQVR4nO3de5DU9Znv8U/fp+fWw8wwNxiQi+IVckIUJyauEVZgqzwaqS1NUrWYtfTojtYqm03CVqLR3a1xTZ3EJEXwj3VlUxU0cSvo0droKgaobMANRAovCRGCAsIM17n19L1/5w/X2YyCfB+c4cuM71dVV8nM4zPf36X7md9096dDQRAEAgDgDAv7XgAA4OOJAQQA8IIBBADwggEEAPCCAQQA8IIBBADwggEEAPCCAQQA8CLqewHvVy6XdeDAAdXU1CgUCvleDgDAKAgCDQwMqK2tTeHwya9zzroBdODAAbW3t/teBgDgI9q3b5+mTp160u+P2QBatWqVvv3tb6u7u1vz5s3TD37wA1122WWn/P9qamokSfMvW6Bo1G15fX3HndeVCJedayVpUtw9qWjqpEpT78Z69/qGVJWpdzwcc66NJJKm3opETOXHe/ucawtFWzJUXSrlXBsuFUy9c/mcc202614rSRXJhKm+pJJzbSaTNvWuTdW4Fwfu65CkfN59n0eMD0cRw3lYXVVt6l1VabsvR2MVzrXZXN7UOwgZnikJ2/ZhPu++lmLg/hepbC6vb37/x8OP5yczJgPoJz/5iVasWKFHHnlECxYs0MMPP6zFixdr586dampq+tD/970/u0WjUecBZDkRI2Hbn/WiEfcHxHjM9sCciLnv/oq4+0CRpHjEvT6asPVWxHbaZAxrD4dtA6jCsPaw7bFTIRl+WSnbmluPZ8nwdG25ZDs+ln2owPa0cVjuxzMi2z6x3O+TxnM8WRE31cdi7vXWZxbGcgBFDGuxDKD3nOpplDF5EcJ3vvMd3Xrrrfryl7+sCy+8UI888ogqKyv1L//yL2Px4wAA49CoD6B8Pq9t27Zp0aJF//NDwmEtWrRImzdv/kB9LpdTf3//iBsAYOIb9QF05MgRlUolNTc3j/h6c3Ozuru7P1Df1dWlVCo1fOMFCADw8eD9fUArV65UX1/f8G3fvn2+lwQAOANG/UUIjY2NikQi6unpGfH1np4etbS0fKA+kUgokbC9IggAMP6N+hVQPB7X/PnztX79+uGvlctlrV+/Xh0dHaP94wAA49SYvAx7xYoVWr58uT71qU/psssu08MPP6x0Oq0vf/nLY/HjAADj0JgMoBtvvFGHDx/Wvffeq+7ubn3iE5/Qc88994EXJgAAPr5CQRDY3vk3xvr7+999RVx9vUIfkiH0x3qPHHHuX+/+hmVJ0owG9//h3BbDO8olnTP9w9+U+8cqEra/lgYl98MahGxvuhvK2t7JPZRxTwkolGxJFVHDO+kqorZTvVh0X0vE+AZA6/OeQ1n3dINi2XZ8GhsbnGvDtvdaq5BzP/bJqO3OmTMkCpRKRVPvykpb8kjIkDwSMrxJXJLk+DgoSUNZW9pHsWBIqoi6n7O5QlH/92e/Vl9fn2pra09a5/1VcACAjycGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwIsxyYIbDRXRkMJhx5gVQ6rJdEO0jiSd05xyrm2aXG/qnTTEfZzqs9XfL5PLOtdmC+5xKZIUGNcSTybdi4u2uJyg7L72VH2lqXex4L6WeMywjZJKJVO5InFDDEre/dhLUqHofjwrDeuQpGiV+36pMPYuhtzjicKBLeKpKNs5bkiEUnWV7TwcTA851xaKtige14dYSRro73OuzRfcTnCugAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABenL1ZcKGSwiG3/KaaGvfNOG/KJNM6GpIR59pY2ZbBNXgs71xbKtt+V8gMFZ1rw3FTa9XWVZvqo4aMr96+AVtvwxlcX2PL4Brod88ay2fdayUpk7VldgWGbLLqKveMQUkq5DPOteGS7SEjlnA/9qWSbZ9EDQFsuZytdzxmu1OEy+73t9zgcVNvldwzCRPuD1eSpGLZPSOvL+2eu5gvuvXlCggA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4MVZG8VTl4goEnabj0lD3EeqKmlax+TamHNtqVwy9bZUR6LGjA3HfSdJubIxAsWSfyMpGrjHfZRy7rEwkhRE3Lfz0KFeU+9Swf0IDQwNmXoPldxjmCSpOlnrXpyznYcRuR+fcMg9FkaSIokK59pM2hZlVRlz3yfRwLbubNZ2fDIF9yiesmxr6R103y+9Q7b78qAhsitbcL+vFUtE8QAAzmIMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAF2dtFlxjqkJRx5yvmph7TlpFhS1TLRxxz21KJm05c4Wie2ZXWSFT7yBwz7LKF23ZVKW8LW+qHLjXB8aMtCAad64dyKdNvUsl93NlyDH76j2uWVnvGUi778N3jtm2MxZ2X0vtoO08LHQfca7N9Nny9KY1znaubWqaauodqukz1eeOH3WuHRy0HZ++AfcsuCN9tizFt/a5b2cp4j4uyo7Ze1wBAQC8GPUB9K1vfUuhUGjE7fzzzx/tHwMAGOfG5E9wF110kV588cX/+SHG+H4AwMQ3JpMhGo2qpaVlLFoDACaIMXkO6M0331RbW5tmzpypL33pS9q7d+9Ja3O5nPr7+0fcAAAT36gPoAULFmjNmjV67rnntHr1au3Zs0ef/exnNTAwcML6rq4upVKp4Vt7e/toLwkAcBYa9QG0dOlS/fmf/7nmzp2rxYsX69///d/V29urn/70pyesX7lypfr6+oZv+/btG+0lAQDOQmP+6oC6ujqdd9552rVr1wm/n0gklEgkxnoZAICzzJi/D2hwcFC7d+9Wa2vrWP8oAMA4MuoD6Ctf+Yo2btyot956S7/61a/0+c9/XpFIRF/4whdG+0cBAMaxUf8T3P79+/WFL3xBR48e1eTJk/WZz3xGW7Zs0eTJk019WhorFY+6RaHUxovOfasr3aNbJClkiJGRbJE2ocA9AiWXscWUhA3RPQ01KVPvqqoKU31/n3scS6q21tR7IOt+fN5+x30dkjSYc4/iiduSdTSl0nbXi8bcI1beOtpr6p0L3LczFrKd46naGufaT1/4KVPv/oPuUVbBkHHdjTFTfW7I/XgODtp+70/E3NfS3uK+vyWpqanZuban3z0SqFgqa+9r+09ZN+oD6IknnhjtlgCACYgsOACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAF2P+cQyna1J1UomYW0ZVNN/r3DcRs21yZaLSuTaXseTGSYWye4ZdXd0kU+8gcM++ypdsv4cUCu6ZUJJUWV3tXHvgcM7Ue/fbfc61hwfc97ckDRnKpyfd89Qk6frPfsJUP7XVfR/+27Y/mHpv3tXtXFss5029o2H383Cg97Cp99Cg+7lSU2PLdlPJPUtRkioq3PvHK2znSmXIvXexZDvHp7W3OdfWHDvxh4qeSL5Q0iaHLDiugAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXpy1UTyTJ9WrIu62vMwx92iYcMi2yYND7vE6mbwtBiMaco/kGCqUTL0tv1lkCrZ4lbpJtab6fMk9juUP+w+Yeh/rd98vQTRu6h2JuO/F2grb8WmKuseaSFLFMffYmXNrW0y9D9a7b2dP7yFT79yQ+7n1yu9/b+odLpadawtVtnNWqWZbfdj9cSWVco/3kqSasvv9J5u3xYEF+X7n2nMmVxnW4fZYyBUQAMALBhAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwIuzNguurqFRyUTMqXZSddK5bzjs1vM9vf3HnWsL6UFT73DJPT+sLPfcK0kKYu6Htrq6wtS7IFv9b//gnvGVzqVNvSsqEu61jtmC70lWuWd2TYrYcgC37eox1Rfz7mvPpWxZcJMnuR/PkGyZaoWie07jUD5j6p0ecs9IyxdtxydkzEdUyL00FjYUSwrC7pmRsajtHC/m3DMGA0Omo2stV0AAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAAL87aLDiFo5JjblsoZst3s0hUuPeuVJWpd9Qw/8Nh2+8KBUN2XCKZMvU+0j1gqh864p6nN7PeljOXc48aU4Uh202S5sya4lwbtixEUjFiO2f7DZmE0UifqXdN3P28bZg0y9R71rnTnGv37P21qffvfv+Oc2086p55JklBYMt1LBbdH0rD0bipdyzufq6Uy7bMyLIhxC4Ucn8Mcq3lCggA4IV5AG3atEnXXnut2traFAqF9NRTT434fhAEuvfee9Xa2qpkMqlFixbpzTffHK31AgAmCPMASqfTmjdvnlatWnXC7z/00EP6/ve/r0ceeUQvv/yyqqqqtHjxYmWztj9RAAAmNvNzQEuXLtXSpUtP+L0gCPTwww/rG9/4hq677jpJ0o9+9CM1Nzfrqaee0k033fTRVgsAmDBG9TmgPXv2qLu7W4sWLRr+WiqV0oIFC7R58+YT/j+5XE79/f0jbgCAiW9UB1B3d7ckqbm5ecTXm5ubh7/3fl1dXUqlUsO39vb20VwSAOAs5f1VcCtXrlRfX9/wbd++fb6XBAA4A0Z1ALW0vPtZ9D09Iz/vvqenZ/h775dIJFRbWzviBgCY+EZ1AM2YMUMtLS1av3798Nf6+/v18ssvq6OjYzR/FABgnDO/Cm5wcFC7du0a/veePXu0fft21dfXa9q0abr77rv1D//wDzr33HM1Y8YMffOb31RbW5uuv/760Vw3AGCcMw+grVu36nOf+9zwv1esWCFJWr58udasWaOvfvWrSqfTuu2229Tb26vPfOYzeu6551RRYYtYyWaLUuAWExEqZAydi6Z1pNPur8rLF2wXlMWw+z4ZHLLF3/Qb6qe0206DoGhby/RG97iPWW22iJqhrHvvKefNM/WOB+7vXTveVzD1TtY1mOp1NOJc2t7Samrdm0471848/1xT79pJ7vFHtZMuMPU+ftj9PDzeZ4snihniiSQpHCScawvlkqm3JV2nVLA9voXd7z4KgmDUa80D6KqrrvrQ5qFQSA888IAeeOABa2sAwMeI91fBAQA+nhhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAAL8xRPGdKKVRSKeQ2H4OSe/6RJc9IkpIVSefa6hr33CtJOnDYPcNuz/7Dpt7RmPt2xnsOmHpne2xrObfJPd9t4VW2rLHd7xxzrq2ZMtnUu7HhxB8hciKHDvecuuiP1NUZs8bK7vswHnbPjZOkQ4ffca6NVvSaeh/uPehc+87BQVPvWMz9/lZXawhUk5TJ2B4ngqj77/IhSwCbpLIhOy4csvUOhd3XXbLtEidcAQEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvDhro3hSqSolK+JOtcWoexTP4GDWtI6g4B6D0TfQZ+r99l73+JbBQVtMSbLC/XeLg3v6Tb2bHY/Le6ZMme5cW9c2w9Q7NmCIWKlwj7ORpKnzLnNv3e0eZyNJyaItzqgk9/M2nbad462V7hFF+ZIt0iZUVe1cO7WqzdS7ps49KmngaLep96Geo6b6Qsj93Mrmc6beCrtn4FQlKkyt8xn3x5VY3H0bS3KLBOIKCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAODFWZsFN9h3TMWsW/ZQND/g3DcWMs7ciHtpNGIoljQ06J4dN6mmytS7rso9Eypz3JYF19TWYKqfMvdPnGtf25839f79Lvf6T7fWm3r39rr3bp41z9Q7rCFTfT7nnh1XF9jy2voPueeeJfMFU+/Wevd93ltKmHrH5k5yrs30HjT1/s9//3+m+v373I9PxJCp9i63XDVJyrjHxkmSCoZrkHDB/dhnC275nFwBAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8YAABALxgAAEAvGAAAQC8OGujeMIhKeKYQFHKDDr3DQyxFpIUllukhCSVQrYonuOGVJP+flvGRpBzj5FpTdlifi793OdM9VPnXO5c+7PH/sXUu6Wq2rk2ks+Yer/zh93u65h5oal3RcNsU31V4B43NXTskKl3suweaZPP2CKEjgy419dNnmHq3dByjnNtZrDW1DtsK1cpnnWuDYVtj0GFgvt9OVQsmXqHAvf6YtF9XBRKbo9XXAEBALxgAAEAvDAPoE2bNunaa69VW1ubQqGQnnrqqRHfv/nmmxUKhUbclixZMlrrBQBMEOYBlE6nNW/ePK1ateqkNUuWLNHBgweHb48//vhHWiQAYOIxvwhh6dKlWrp06YfWJBIJtbS0nPaiAAAT35g8B7RhwwY1NTVpzpw5uuOOO3T06Mk/8CqXy6m/v3/EDQAw8Y36AFqyZIl+9KMfaf369fqnf/onbdy4UUuXLlWpdOKX+3V1dSmVSg3f2tvbR3tJAICz0Ki/D+imm24a/u9LLrlEc+fO1axZs7RhwwYtXLjwA/UrV67UihUrhv/d39/PEAKAj4Exfxn2zJkz1djYqF27dp3w+4lEQrW1tSNuAICJb8wH0P79+3X06FG1traO9Y8CAIwj5j/BDQ4Ojria2bNnj7Zv3676+nrV19fr/vvv17Jly9TS0qLdu3frq1/9qmbPnq3FixeP6sIBAOObeQBt3bpVn/ujLLD3nr9Zvny5Vq9erR07duhf//Vf1dvbq7a2Nl1zzTX6+7//eyUSCdPPCQXv3lyUCu6haqGw7aIvaigPMoZwN0mhsnttfUOlqXdLpXuG3Sc/dZ6p9wWfds92k6Tjh9yz+hLFPlPvmVOnOteWLTtcUkvTZOfaYtZ9f0vSUK97vpck5Yvu/QsZ2926JPc8vd3v7Df1fvW1rc61n77ctk8aWhqca/sHbPl4MdvdTY3nuOcplo2PQaW8Ia/NkAEpSX2He51rcwPuOyVXcFuzeQBdddVVCoKTT4bnn3/e2hIA8DFEFhwAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwAsGEADACwYQAMALBhAAwItR/zyg0VIullSOuM3HTM494yte5Z57JUnRaMy5NhK25TDNbpnkXFuRtP2ucM50989UmveZz5266I+0zplrqt+++THn2mnt7vtEklouusS5Nj55lql3tDLlXDuUdc+7k6RM/4CpvufAPufa4z22vLZSYci5NllTYerd2Oh+/9l34BVT7+bWKc61xSHb8QkyOVN9KH3cubYUZGxrcQ3FlJRMuO9vSYq3uNf3J0LOtdm8Wy1XQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAAL87aKJ5YJKpYxG15xwfco0RKWfc4CUlKViadayNh98gMSWpqqHSu3Xew19R71ieXONdOvcS99l22uJzCQNq5NlXjHn8jSZPP+4RzbTpab+r9+iu/dq7NZdy3UZL6+3tN9Ufe2etcGynZIqEqKtwfBqbMcI+/kaS55812ri1Gqky9Y5E699p4wdQ7ms2a6ofefse5tlwsmXoXDZcJg5GIqXdlg/s+b25rcK7NZN22kSsgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBcMIACAFwwgAIAXDCAAgBdnbRZcPptTuOyWJ1SZcN+MUIUtKykWLjrXBiX3WklKVruv5X/f+L9NvT+9dKFzbW1js6l3zx9+a6qPGPZh70Cfqffht3Y61x4YsGVwbXjqKefa6mTM1DubGzTVtzS7Z+TV1tgy1fbs3+dcmzccS0mqbzvHufa8S+abequUcC491rvf1HrImBl5POO+X0KB7WE3myk71w4GtjzKYNA98+6COve+Wcc4Qq6AAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABenLVRPOUgr3LgGEHhGNkjSaGie6yFJBWDgnvvkC0GoyJR61z7ifm2mJJEzD0a5o3tr5h6Hz+w21Sfy7nHfQwcP2bqvW/XG861g0HS1DtWcl93ddQW8VRbYYvLmTzJPYrnYE+3qXex4H6ODw3YIoT27dlrqH7d1HtwcMC5tiJqu28WE02m+qNF9/tyMllh6l1Z437eJqPu8USSNDDU71xbLLvHDRUdH5O5AgIAeGEaQF1dXbr00ktVU1OjpqYmXX/99dq5c2QYZDabVWdnpxoaGlRdXa1ly5app6dnVBcNABj/TANo48aN6uzs1JYtW/TCCy+oUCjommuuUTqdHq6555579Mwzz+jJJ5/Uxo0bdeDAAd1www2jvnAAwPhmeg7oueeeG/HvNWvWqKmpSdu2bdOVV16pvr4+Pfroo1q7dq2uvvpqSdJjjz2mCy64QFu2bNHll18+eisHAIxrH+k5oL6+dz+7pb6+XpK0bds2FQoFLVq0aLjm/PPP17Rp07R58+YT9sjlcurv7x9xAwBMfKc9gMrlsu6++25dccUVuvjiiyVJ3d3disfjqqurG1Hb3Nys7u4TvzKnq6tLqVRq+Nbe3n66SwIAjCOnPYA6Ozv12muv6YknnvhIC1i5cqX6+vqGb/v2uX86IwBg/Dqt9wHdeeedevbZZ7Vp0yZNnTp1+OstLS3K5/Pq7e0dcRXU09OjlpaWE/ZKJBJKJGyvXQcAjH+mK6AgCHTnnXdq3bp1eumllzRjxowR358/f75isZjWr18//LWdO3dq79696ujoGJ0VAwAmBNMVUGdnp9auXaunn35aNTU1w8/rpFIpJZNJpVIp3XLLLVqxYoXq6+tVW1uru+66Sx0dHbwCDgAwgmkArV69WpJ01VVXjfj6Y489pptvvlmS9N3vflfhcFjLli1TLpfT4sWL9cMf/nBUFgsAmDhCQRDYQpLGWH9/v1KplLr+8jOqiLvNx2P733LuH0/WmdZTKrrnZBXknpUkSdNmn+veO2TLMatvnnHqov/W1Gp75WF+qM9Unz60x733UUt2mDRtxjTn2kLMlr/2+1dfc67NDBw39U5W2p73DMXc/1qezuZMvQO559jlg5Cpd0jumYTVSfc8NUnKFTPuxTFbVl8pbKt/Z+AP7sVVeVPvyoT7dUJF2fa0flJx59oL5p7nXDuUKejG//P/1NfXp9rakx9XsuAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF6c1scxnAnlckjlslvsRzzqHptRES3bFhJ2jx4JIraol3LePebnyJETf6DfyQwedq9PFmyfQls2RLdIUv2kBufaurbJpt7FknvszDsHbPswkHtKVThsuyvli7bYpkjIPdKmqqLS1LtouEtELMWSFHLfh6W8LeIp7Pj4IEn9Q7aopHzCEPMjqabN/TxMJ3tNvQfK7tE92bTtmqKhdqZzbWOT+/04nXZbM1dAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC/O2iy4cCihcMhteRWJpHPfQLYMrqqke65WVU2jqfdQIetc21ATN/WOGrYz39dj6l0O29YyFHPPD2tunmFbS949J2vO3Kmm3r/6xXrn2nwwZOodC7nnmElSZtC9f21Nral3POr+MBAJ2bLgBrPu5/ieg7a8tt5e93M8F0qbek8+z/a7+ZQ698egfGC7/xw/4n7s41n3zEBJqprinu+WGSq512bcarkCAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4cdZG8cSiIcWjbvNxKJdz7hupqDKtoxxJONcOFTKm3pFY4FybiLtHfUhSLOa+nfHKlKl3qta2D7sPu0f9DE2xxeU0tc92rn3n0BFT74suvcK5dvDwAVPvP/z+dVN9erDXuTYasZ2HqZR7dE9Itiieg++475e9b/eZeocT7udhbbN7pJYkTa63xRmFDJFDoWO2+8+k4+4P01Oa6k29p9a53992vdHtXJvJFpzquAICAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeHHWZsE1NYRVWeE2HwtHjzr3zZRsWVbptHttEC6Zekej7ru/trbB1DseiznXZtL9pt7JmPG0ybvXb/3Vr0ytZ85xz5nbv989y0qSwuGQc21lwn1/S1LEkDEoScmke35YetCWBZfJuNcXi3lT7+qk+3Z++n+dZ+pdUeOe11aMFE29S4UhU31mn3sWXHigwtS7qbLGufZ/nXeRrXdds3PttoN7nGuzebf9zRUQAMAL0wDq6urSpZdeqpqaGjU1Nen666/Xzp07R9RcddVVCoVCI2633377qC4aADD+mQbQxo0b1dnZqS1btuiFF15QoVDQNddco/T7/k5166236uDBg8O3hx56aFQXDQAY/0x/zH/uuedG/HvNmjVqamrStm3bdOWVVw5/vbKyUi0tLaOzQgDAhPSRngPq63v3A6Tq60d+CNKPf/xjNTY26uKLL9bKlSs1NHTyJ/RyuZz6+/tH3AAAE99pvwquXC7r7rvv1hVXXKGLL754+Otf/OIXNX36dLW1tWnHjh362te+pp07d+pnP/vZCft0dXXp/vvvP91lAADGqdMeQJ2dnXrttdf0y1/+csTXb7vttuH/vuSSS9Ta2qqFCxdq9+7dmjVr1gf6rFy5UitWrBj+d39/v9rb2093WQCAceK0BtCdd96pZ599Vps2bdLUqR/+meILFiyQJO3ateuEAyiRSCiRsL0nAgAw/pkGUBAEuuuuu7Ru3Tpt2LBBM2bMOOX/s337dklSa2vraS0QADAxmQZQZ2en1q5dq6efflo1NTXq7n73neWpVErJZFK7d+/W2rVr9Wd/9mdqaGjQjh07dM899+jKK6/U3Llzx2QDAADjk2kArV69WtK7bzb9Y4899phuvvlmxeNxvfjii3r44YeVTqfV3t6uZcuW6Rvf+MaoLRgAMDGY/wT3Ydrb27Vx48aPtKD3TJ0aV3XSLV8rFXLPVtq1z5bx1HP4w7f5j+VLtueyqqvdd396qM/Uu1QedK6NGF+Nf+ywe/aeJA0MuudwZQu27YwE7vU11ZNMvXu6jznX7k+7Z4FJUjlwz5mTpObJ7lmAoXLB1Pt473Hn2kSV7RyvS7nnmMUjtvMwlzdkL0ZtWX3pnG0t+UH3/lVlW+/Z7e7vqWxrsWVG7tvvnqV49LD7Y2eu4HZsyIIDAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHhx2p8HNNZq62KqrnSLt8gYIiImNUVsC6mqdC490pMztc7m88610XitqbehtcqOsRnvKZRs29mXcY96qUraol6yQ+4ROJnsEVPvvGG/lIz7MAhs5+Fgv/s5XlubNPWurU0512YytiirI0fdj311dZWpdyjs/vtzqOgeqSVJ8ahtHybc08AUj9uO/Tmzz3GuzQzZtnPTpjeca3f8/pBzbbFUdqrjCggA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgxVmbBRepiCpa4ba8itq4c9/6atvMjWbcc89iSbf8o/f0Hzfs/pJt3cmKJvfWMdu6S7leU3280n07Y1H3YylJkYh7Vl8usG1nvuAeqBcEIVPvkC2yS0HePfOu5F4qSYpF3TIXJUlxW1Zf73H3LLhMvmDqnapzz0eMGnLjJClsPA+HVHSu7TkyYOp9fNC990C6z9T7xQ2/c67tMcQAlstuJzhXQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAALxhAAAAvGEAAAC8YQAAAL87aKJ70YFShsmNESKTauW91lS2nJJZ0z0ypSlSYeqdS7tEwg/0ZU+/B/h732qGSqXcha6uviTc411bEDLEwkoo596ikaNT2+1bcUB5LREy9QyHbWiqr3e+qYeO9ulhyj3qJJ23Na+vco5KOHbNF1AwYopVq693PQUkaKrrHMEnSm28dda793av7TL2b690jh5qnuu9vSVLYfR82pmqca0vlst4+furHWq6AAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF4wgAAAXjCAAABeMIAAAF6ctVlwB/ZJlY7Rarle9wy2msnuuVeSVJEsONem3CPpJEn19e67fzA9ZOrd2+tef/xo3NT7uHvslSQpUnbPSSsH7tl7klQqGXLpyrYMO8tvZ6FwyNQ7ErXd9TIl99UEtlNcsbL7OV4cOmbqXcq4n4elqC0HsHfQvXfeduh1zJi9+NYu9ztF79G0qXc+7b74llSLqfcF06c411p2SaFU1m/eOvW5whUQAMAL0wBavXq15s6dq9raWtXW1qqjo0M///nPh7+fzWbV2dmphoYGVVdXa9myZerpcU9lBgB8fJgG0NSpU/Xggw9q27Zt2rp1q66++mpdd911ev311yVJ99xzj5555hk9+eST2rhxow4cOKAbbrhhTBYOABjfTH+Ivvbaa0f8+x//8R+1evVqbdmyRVOnTtWjjz6qtWvX6uqrr5YkPfbYY7rgggu0ZcsWXX755aO3agDAuHfazwGVSiU98cQTSqfT6ujo0LZt21QoFLRo0aLhmvPPP1/Tpk3T5s2bT9onl8upv79/xA0AMPGZB9Crr76q6upqJRIJ3X777Vq3bp0uvPBCdXd3Kx6Pq66ubkR9c3Ozuru7T9qvq6tLqVRq+Nbe3m7eCADA+GMeQHPmzNH27dv18ssv64477tDy5cv1xhtvnPYCVq5cqb6+vuHbvn22j6sFAIxP5vcBxeNxzZ49W5I0f/58/frXv9b3vvc93Xjjjcrn8+rt7R1xFdTT06OWlpO/Nj2RSCiRSNhXDgAY1z7y+4DK5bJyuZzmz5+vWCym9evXD39v586d2rt3rzo6Oj7qjwEATDCmK6CVK1dq6dKlmjZtmgYGBrR27Vpt2LBBzz//vFKplG655RatWLFC9fX1qq2t1V133aWOjg5eAQcA+ADTADp06JD+4i/+QgcPHlQqldLcuXP1/PPP60//9E8lSd/97ncVDoe1bNky5XI5LV68WD/84Q9Pa2GlWINKMbc/zRXin3LumyvnTOsIF48411akbHEsdZPdI4QmhW35KvVDZefa3mNJU+/eI+7ROpKUSbufZqWiLRZIgftFfLnovk8kKZvJOtfG47Z1R6K2fTiQdV97ZtB93ZIUC/LOtTXhGlPvctj9Va2Fgu0ZgUSVe2xTheNjyXvq4u77RJJmqs659pJ5Vabec+bOc64957+fHnF12eXucUb7Dww61+byRek3b52yznTEH3300Q/9fkVFhVatWqVVq1ZZ2gIAPobIggMAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHjBAAIAeMEAAgB4wQACAHhhTsMea0HwbrzGUNY9CiNjqA3FCqb1lMvuETjhIVsUTzRtWEu4ZOqdzrhHt6Qztn0yZIiFkaRM1j0yxbC7/9sYRvHk3PdLKbAd+0jJdjwzOfd9mM3bjmcQuNdHjZFQ2bx7fc567EPu+yQS2KKPcgXbYvJF9+MZM/a2PBYOpm0xTBnDOZ6zHMv/3sb3Hs9PJhScquIM279/Px9KBwATwL59+zR16tSTfv+sG0DlclkHDhxQTU2NQqH/+a2yv79f7e3t2rdvn2praz2ucGyxnRPHx2EbJbZzohmN7QyCQAMDA2pra1M4fPK/Upx1f4ILh8MfOjFra2sn9MF/D9s5cXwctlFiOyeaj7qdqVTqlDW8CAEA4AUDCADgxbgZQIlEQvfdd58SCdsHS403bOfE8XHYRontnGjO5HaedS9CAAB8PIybKyAAwMTCAAIAeMEAAgB4wQACAHgxbgbQqlWrdM4556iiokILFizQf/3Xf/le0qj61re+pVAoNOJ2/vnn+17WR7Jp0yZde+21amtrUygU0lNPPTXi+0EQ6N5771Vra6uSyaQWLVqkN998089iP4JTbefNN9/8gWO7ZMkSP4s9TV1dXbr00ktVU1OjpqYmXX/99dq5c+eImmw2q87OTjU0NKi6ulrLli1TT0+PpxWfHpftvOqqqz5wPG+//XZPKz49q1ev1ty5c4ffbNrR0aGf//znw98/U8dyXAygn/zkJ1qxYoXuu+8+/eY3v9G8efO0ePFiHTp0yPfSRtVFF12kgwcPDt9++ctf+l7SR5JOpzVv3jytWrXqhN9/6KGH9P3vf1+PPPKIXn75ZVVVVWnx4sXKZm2Bir6dajslacmSJSOO7eOPP34GV/jRbdy4UZ2dndqyZYteeOEFFQoFXXPNNUqn08M199xzj5555hk9+eST2rhxow4cOKAbbrjB46rtXLZTkm699dYRx/Ohhx7ytOLTM3XqVD344IPatm2btm7dqquvvlrXXXedXn/9dUln8FgG48Bll10WdHZ2Dv+7VCoFbW1tQVdXl8dVja777rsvmDdvnu9ljBlJwbp164b/XS6Xg5aWluDb3/728Nd6e3uDRCIRPP744x5WODrev51BEATLly8PrrvuOi/rGSuHDh0KJAUbN24MguDdYxeLxYInn3xyuOa3v/1tICnYvHmzr2V+ZO/fziAIgj/5kz8J/vqv/9rfosbIpEmTgn/+538+o8fyrL8Cyufz2rZtmxYtWjT8tXA4rEWLFmnz5s0eVzb63nzzTbW1tWnmzJn60pe+pL179/pe0pjZs2ePuru7RxzXVCqlBQsWTLjjKkkbNmxQU1OT5syZozvuuENHjx71vaSPpK+vT5JUX18vSdq2bZsKhcKI43n++edr2rRp4/p4vn873/PjH/9YjY2Nuvjii7Vy5UoNDQ35WN6oKJVKeuKJJ5ROp9XR0XFGj+VZF0b6fkeOHFGpVFJzc/OIrzc3N+t3v/udp1WNvgULFmjNmjWaM2eODh48qPvvv1+f/exn9dprr6mmpsb38kZdd3e3JJ3wuL73vYliyZIluuGGGzRjxgzt3r1bf/d3f6elS5dq8+bNikRsn1NzNiiXy7r77rt1xRVX6OKLL5b07vGMx+Oqq6sbUTuej+eJtlOSvvjFL2r69Olqa2vTjh079LWvfU07d+7Uz372M4+rtXv11VfV0dGhbDar6upqrVu3ThdeeKG2b99+xo7lWT+APi6WLl06/N9z587VggULNH36dP30pz/VLbfc4nFl+Khuuumm4f++5JJLNHfuXM2aNUsbNmzQwoULPa7s9HR2duq1114b989RnsrJtvO2224b/u9LLrlEra2tWrhwoXbv3q1Zs2ad6WWetjlz5mj79u3q6+vTv/3bv2n58uXauHHjGV3DWf8nuMbGRkUikQ+8AqOnp0ctLS2eVjX26urqdN5552nXrl2+lzIm3jt2H7fjKkkzZ85UY2PjuDy2d955p5599ln94he/GPGxKS0tLcrn8+rt7R1RP16P58m280QWLFggSePueMbjcc2ePVvz589XV1eX5s2bp+9973tn9Fie9QMoHo9r/vz5Wr9+/fDXyuWy1q9fr46ODo8rG1uDg4PavXu3WltbfS9lTMyYMUMtLS0jjmt/f79efvnlCX1cpXc/9ffo0aPj6tgGQaA777xT69at00svvaQZM2aM+P78+fMVi8VGHM+dO3dq79694+p4nmo7T2T79u2SNK6O54mUy2XlcrkzeyxH9SUNY+SJJ54IEolEsGbNmuCNN94IbrvttqCuri7o7u72vbRR8zd/8zfBhg0bgj179gT/+Z//GSxatChobGwMDh065Htpp21gYCB45ZVXgldeeSWQFHznO98JXnnlleDtt98OgiAIHnzwwaCuri54+umngx07dgTXXXddMGPGjCCTyXheuc2HbefAwEDwla98Jdi8eXOwZ8+e4MUXXww++clPBueee26QzWZ9L93ZHXfcEaRSqWDDhg3BwYMHh29DQ0PDNbfffnswbdq04KWXXgq2bt0adHR0BB0dHR5XbXeq7dy1a1fwwAMPBFu3bg327NkTPP3008HMmTODK6+80vPKbb7+9a8HGzduDPbs2RPs2LEj+PrXvx6EQqHgP/7jP4IgOHPHclwMoCAIgh/84AfBtGnTgng8Hlx22WXBli1bfC9pVN14441Ba2trEI/HgylTpgQ33nhjsGvXLt/L+kh+8YtfBJI+cFu+fHkQBO++FPub3/xm0NzcHCQSiWDhwoXBzp07/S76NHzYdg4NDQXXXHNNMHny5CAWiwXTp08Pbr311nH3y9OJtk9S8Nhjjw3XZDKZ4K/+6q+CSZMmBZWVlcHnP//54ODBg/4WfRpOtZ179+4NrrzyyqC+vj5IJBLB7Nmzg7/9278N+vr6/C7c6C//8i+D6dOnB/F4PJg8eXKwcOHC4eETBGfuWPJxDAAAL87654AAABMTAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgBQMIAOAFAwgA4AUDCADgxf8H/IlN+ZvxeyIAAAAASUVORK5CYII=", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import torch\n", "\n", @@ -62,121 +44,22 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "no: 8\n", - "patches 10: tensor([[[0.2314, 0.2431, 0.2471],\n", - " [0.1686, 0.1804, 0.1765],\n", - " [0.1961, 0.1882, 0.1686],\n", - " [0.2667, 0.2118, 0.1647]],\n", - "\n", - " [[0.0627, 0.0784, 0.0784],\n", - " [0.0000, 0.0000, 0.0000],\n", - " [0.0706, 0.0314, 0.0000],\n", - " [0.2000, 0.1059, 0.0314]],\n", - "\n", - " [[0.0980, 0.0941, 0.0824],\n", - " [0.0627, 0.0275, 0.0000],\n", - " [0.1922, 0.1059, 0.0314],\n", - " [0.3255, 0.1961, 0.0902]],\n", - "\n", - " [[0.1294, 0.0980, 0.0667],\n", - " [0.1490, 0.0784, 0.0157],\n", - " [0.3412, 0.2118, 0.0980],\n", - " [0.4157, 0.2471, 0.1098]]])\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfYAAAGFCAYAAAAPXdHTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcwElEQVR4nO3c+bftd13f8T2febjzzR3InDTBJCCDmFIRta2Ci6K2i2qV1lKpNClYFqsqBbTFVqpFWlEoilJbXFWXYgS1A6IUgcAyARWQEMhwk9zc3PkM+5x9zh77J+TV3ftdJ+uzHo+fn2sP3+/+ft/n+8N51yeTyaQGABShsdcfAAC4cgx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAK0krDF7/km6Nube1S1M00xlG3v5Ptz/noZ78UddN4zcteEHWH9i9E3cHVxajrNNtR9/r/9LtRN633/Mvvy8Jm9nO6dHkt6vrD7Nz/m/d/JOqm8a43fX/UNUaDqNvd3Y26nZ2dqHvbr/7PqJvGz9z1yqgb1UZRt93rRt3K6nLUvennfi/qpvUf/sV3Rl1/tx91zVp2PTebzah707s/HHXTeN+PZ9f8wkJ2z2u3Z6OuFx7Lf/S290XdND7w9tdlYSO736W/j+GkHnV3vf0/P23jiR0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAAoSb5770l9lm93WLlyIuv3ZIqJa/UAYVmj73Kmoq88djrqtcbadrzvKNq9V7exjX4m67Z1sw9J2L9u+Nhhl2wmr9NUvfDbqZlvZuRoOs+/UDLdaVenS6YeibntnK+qG4+z3Ud85EHVV6154POoG4TbBuVZ2L+uGm8qqdOrL90Xd/Hy2ea7eyLbu1cNtm1U68/AXo257J9s2ORxkXbM1E3UJT+wAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFCQeL3VXKueheHynKvDjXLXHFnJXrBChw8djLq5dAtTPTuWvd2dqKva5na2WWsSfq/O3Fz2xsNnwOa9Zva378r++agbDrLv1GmHx6hCs/OLUdfsZBf9bj/7PQ+G4b2mYs1mdo9qLWTnajY8TsN6tsmvSu3w3A9r2blqhqd0cSG7jqo0GmfbIQfDbKNcI/zumxvrWZi85xV7JQBgzxnsAFAQgx0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUxGAHgILEm+dm68OoW1rKXvKm4/ui7sBcM+qq1B5nG7O6l/pRNxpnf0/1trNjXrV+bzvqllezbVWtcAPX2vpm1FVpNrxC9i9lG7M2N7KtYv2dvd8+1ttci7pJuH1scSHbzDjo96Kuao3RKOraM9nveTTKNpW10jVtFRqHn7XT7kRdY5zdy3a7l6OuUsPsfj8TjqZhuMlufSvb8JnwxA4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAWpTyaTyV5/CADgyvDEDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAApisANAQVpp+LxrD0fdXG076p597ZGou2q5HXU/ec8DUTeNt77ixqgbha83qWeHvT/MVgy88w+q++61Wq1293dcH3X7V1eirlXP/p48e/5C1P3iRx+Numn882+/LupWF+eibjTIfiWb29l19O6PPhZ103jNtxyNusW55aib6SxGXbO2E3U//bt/FXXT+ulXPy/qOkurUdfbyr7XfHM26t74Kx+Lumm84weeH3U7O/2o6w+GUTeuZfe8d9zz5aibxg9/27VRt7Y9iLrudvbddwbZffGPvvTk0zae2AGgIAY7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFiRfUHFrNliYstZtRNzubdY1mtrCgSosL2fKRwTBbPjKu1aNuMsmWP1StPs7OwaifLWwYT7JuMtr77z+pj6Nus78VdaNR9rvfHmXvW6Vh+Bk2t7LzefpSdozajb3/7rVarfbomWyhzOCpbJFSbz1bOvSsgzdEXZXarRNRV19aj7rdyxejrtvNfiNVWt/MzvuF9V7UPfp4doxGzXgcPy1P7ABQEIMdAApisANAQQx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUJB41c2xQwtRt9wZRt3ifCfq6uGWsmplm9fqk2xj1m4v20DVCDfUVe3A0lLULSxk2wk31rNNXSvLy1FXqUm2Ke7U6ew7dXez1+s8A5avzYRLwFrtcAPXxbWo2w2PedW+cuqpqFtZzq6PO299ftRtnMk2WFZpkt2iaisH21G3u52Nmm537581Z9rZdzp5NDvvhw8fibqzG9nGu8TeH0UA4Iox2AGgIAY7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFiTfP7V+ay16wvxZ1M+3sredn5qOuSqPwMA3G2da91dV9UTeZZBvvqtZqZn//DQbZ5qT5xcWoe/L8btRV6aFH16Pu/GZ27rezrHb13N5vX3vl854TdSeuys7nb9//cNTd+7Vs41vV6vXs+ms1sm5z7XzUbXf3/nc/HFzMwlG2HXN2Ntvm1pnd+9/9/Hz2WYej7GJ+1sljUbd0aTPqEp7YAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIPXJM2W9GQDw/80TOwAUxGAHgIIY7ABQEIMdAApisANAQQx2ACiIwQ4ABWml4c/80N+Mut6l09kbz81H3XA4irqf+I3PR900fuy7b4u6Vj1bCdBu1qMu/avrrb/1hbCczpu/65aoW923HHX9UXac7r3vkaj7gy+dj7ppfMOxfVE3DK+kfnhWb53Pfve/+cBa9sZTuO+ub4y65WYn6j57cS7qPvHouaj75U99Luqm9Ypnn4i6cT07V5Nh9r6N4TjqPvK1s9kLTuHub7sx6k6cPJK9YHhv7Pd7Ufe2D1Z37n/y+54bdTv9QdR1mtk1Pze3EHU//l/ufdrGEzsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIPGCmn0HD2XdYraEotFoR93axuWoq1JrNvxOo2xRxbiWLaCYtOPTU6nFfQejblCbjbovP/xg1G3tbkVdlWazr1Sb7WTnam4hW8y0rxluM6nQ/eEClGE/++67K0ej7tC+8KBX7OhqtnBpMNyJuu1w+crWdrbMpUr9Yfb7qw/62QtmO7lq7UYYVqjTyn7P7bAb7u5G3SRc3JXwxA4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAXJV5uFm+Lq7axLzcxe2debxvzCQtS1wr+TGo2sG4Qb6qo2u5xtHbzw1GbUbV/Itglet3/vN5CdWMw2Yc2GG+Vuvv541DV2s21mVRruy877RrgdstVcj7qlTna9Ve3m666PuutvfFbUPfLYn0XdAw+ejroqzbTCbWmTbtQNh9moabQ6UVeldiubOeNxdn8eh2v36vUr95ztiR0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAAoSb57r7Qyirj7oha84jKqtrY3w9arT38m2MA0b2aa07na2oW0j7Kq2tbMVdZNh9nmvPphtYrr+2N5vHbz92mwT1vGb7oi6ziTbKHd5PbveqrR08kAWXmxG2cmjV0Xd2lb2e6vaN9x5Y9Qt78u2Di7vuyXqLp9/Blz32e251g63BDYmM1E3GI+yN65QuFCuNhpkB6mR3e5qk8kkC5P3vGKvBADsOYMdAApisANAQQx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUJD65EquuwEA9pQndgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFAQgx0ACtJKw3e96VVRN+ldzrrw3+dbtWHUveE9fxJ10/iFN7w86p4834u6rz5yKupa7ewY/fdPPBx103r9d90WdTtnz0fdjYezc/ryv/3Xou7W130y6qbxkbfcGnVLx78x6g4eOBp1586fjbpvuev9UTeNz/36G6KuMW5HXafRjLpz509H3Tf/yAejblqf+cA/jrqt3ey6P/1YN+pOPfhk1L31t++Pumm89RU3R93M/FLUTVrzUVdvZPe8t/zXP426afz0q++Muvp4EHbZ/W7UmI26N3/w3qdtPLEDQEEMdgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFAQgx0AChIvqFldXYy6YSv7Z/xudyfqJoNR1FVp/dJ61J16LFsq0u1miyrmZp8Zf3edeeRM1B2Z7UTd8eNXR93qsWujrkrt5dUsnM2WtJy444XZyz2VLWmp0pGj+6NuVMuu5a2trLtq/lDUVW1hKVu+Ul/I7o0nFo5F3dJqtsSoSjfeclPUnTt7MeoG9ez62OnvRl2VRuHytIWZbKFMv5fd79ud7BglnhmTAwC4Igx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEHizXOba9mGoVZ/M+ra9fBvimaWVanVzD7EdjfbULdvaSHqVheyzUZVW2zNRN3hYwei7vjtL4m6Lz7Rj7psn9d0Hvxa9hnuvCrb0ra2lr3ekevviLoqrd58Z9T1d89nrzcZR93GuexeU7WVgyej7qr94bkfZddR+/Z9UVelF77slVH3qT/8cNQ98Xj2G2lewe1rU2tmY7GXLairDcLn58ZgkL1g8lpX7JUAgD1nsANAQQx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAKEm+ea9azbtTrRt2klr1gozbM3rhCo/pO1F0OFwdtbGQriya72Zayqu1bzjblveClL426Eze/KOo+9IFfjbq/dVeUTWVfazHqmv1e1J1++KGoO3rdrVFXpYVjt2fdJNs2uX3pXNTNjfd+81qtVqu1910ddRc2t6Nu9dC1UXfg6DVRV6VjN94cdY3l7PVGneweWm+Eg6ZCu6Pwsw5HWTfJuuEwHsdPyxM7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUpD6ZTLI1aADAM54ndgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFAQgx0ACtJKw/e86ZVR17/wWNTVG9nfFDPhnx4//P77snAK//bvf33U/dFfPBF1Bw8cjLoT88Ooe9f/fjDqpvXBN39n1L34e/5B1F0+14u6P/3gL0fd63/93qibxqf+3d+Lutmjx6Ju9erbo+7ITc+PusWTd0TdNLYvPR51/eFW1A16G1E32s1+H0dveknUTeuTf/i+qPvCF7N7z50vemHU3fqcb4q69vLNUTeNJ77yoaj7+B/+WtRdWr8UdeNwLvzI2/5P1E3jnW95UdQNw9/p+vm1qNvdnI+6d/7Ol5+28cQOAAUx2AGgIAY7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKEi8oGY8HEVdb3ccdZ2FxahrtdpRV6VmYyHqbji6L+pm57K/p665+mTUVe2OF7806q66OVu+8uf3fiDqnnUyO55VOvrs26Kuc+j6qGvNr0Td9k436rKraDrnz5yKurNPZotsLp/NFjiNBttR97KKF9QMttei7uDB7B71+JOfj7ojVx2PuqMVLqgZbme/v0lvN+rqW5ejbjTJlr5UabJ7LurmZrLz3jmadRsz9ahLeGIHgIIY7ABQEIMdAApisANAQQx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCxJvn2s0svbyZbY0a7WRbdubm56KuSr3GMOoOH5iPusfPrEXd9V//7VFXtRO3pZ8j2xQ32NyKupWlbEtblQ7d9Jyo22rtj7ovff7Pom63lx2jl9/416NuGp/+43ui7sLpx6KuOepH3exsfFuqVH/jyai7/aYbom7YzDZYtpurUVelpU527bV2dqJu+9TpqEs3nFap98iFqOs2m1E3fyA770eOHYi6hCd2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoSLziabeXbRian8lesj6bbe1ph1vfqtQZ70bd3GL2nV7xqldE3Z3f8a1RV7XlQ0ei7uzDX466ZnhO1zbXo65K5x/9StQ9uZltzPr4PfdE3eJcO+pe/g/fEnXT+MpnPxZ1R49kW8qWl7INXI888XjUVe2RU6eibv+xa6Luptuel73xaCbrKrSxlt3ztsMNopd72TVfn+z91sFLl7Pv1J1Mom7SzWbnLatRFvHEDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAApisANAQQx2ACiIwQ4ABalPJuH6HADgGc8TOwAUxGAHgIIY7ABQEIMdAApisANAQQx2ACiIwQ4ABWml4X983bdG3dpTT0Rdo7OQvfFkFGVv+62/yF5vCm9/1e1RN7+6HHV/59Wvjbqrbnhu1C0cuS3qpvWFz3w06i4/+VDUnf7ip6Ju88xXo+617/1M1E3j4+/+wajrTuai7v5PfTrqDqxm18fd78uO5TTec9eLo+7IVQei7szZp6Lu4kYv6n7i1/4y6qb1c2/8tqi75sbs+jt57XOirtvdjLqX/t27o24a997z7qj7y/s/GXX33Zd1c3OzUffzH8ruNdP4se/9uqhrtGaibnN7I+pO3JjNjx99x/1P23hiB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIAY7ABTEYAeAgsSb52q1cVYN+9kbt+ejbjTMNs9VaXeQfacjK/ui7n99+Pejbv+RL0Xd977x30fdtM498pWo62+vR127nW1sWlzINjFVqdVoRt1Cux11Rw9nW9p6m5ejrkrNcOvjxfMXom7Qz15vaTbb4le1frcbdV/9/H1Rd+aBB6Nud5ht3qt089zHPxx1o/T6OBFuGl3I7rVVajxrJ+pmx8Oo21fLfs+3PPvaqEt4YgeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUxGAHgILEm+fG43rUdVrZJqLZVrbJrtbI3rdKK3PZ1qRxfxB1Fy48FXXd81lXtc3Hsg1441p27vfvy7avrR47FHVV6sxl2+9OP5mdq0ltEnWNxv/DUsiK9IfZZq1mPdu6tzCbbZschreGqs01w3NQz87pqJ9tZmyE99oqddezbYL9mWxL3tKx3ajbmluLuiptrpyOup2t7Ln4wPJ1UXcw3EqZ8MQOAAUx2AGgIAY7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFiddbNeozUTc7Mxd1k1q21WphLttWVaXDBw5H3fZgJ+oOLHWirhUeo6r1189G3biRfa/tdrZa7MiRa6OuSo2FbBvUzbefiLpP/8nHoq4/2Y66Ko3r2e2h180+6/JStsWv09r7rXu1Wq02284+R3cnu+4fOXM56tbW9v66/8JDZ6Lu0E3Zs+Hx1Wwu9CfZPaRK3W52/Ds74cbF49k9pLc9irqEJ3YAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAApisANAQQx2AChIfTKZTPb6QwAAV4YndgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFCQVhq+5+6XR93a6a9GXXN2IeoWFrLu7vd/Muqm8d7XfXPUNdvZfw62mu2oa7ez7/4DP/N7UTete37qB6NuZTn7vE+dPxt1N3zdc6PuBa96c9RN49SnfzPqTp+7EHWD7fWo655/Mupe/oZfiLppvPuul0bdhaeeirqVlcWwW46617zrY1E3rV96/TdF3YMPPZp1p7Jz35jJrqN77s9+I9N49cuui7prvy47VzP7dqJu41LWveNnH426abzxn1wVdccP74+6O8L72COPrEXdD/2r33/axhM7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUJN48d+RQ9jfA4OLFqOuNxlG3tRVllVrfuBR1rVZ2OJeXD0Rdp51tqKtab2sj6uba4c+pn3X3ffrTUfeCV2VvO40H/jz7DE88kW1fazTqUTc/s/fnvtmcibq5uWxT2la3F3W9XtZV7cyZ7JwuzmXH6c7n3hR1s0vZNrcq3X79yagbDbajrvd4tlGusTkbdVU63DwYdc+96dnZ660eibr7zzwSdQlP7ABQEIMdAApisANAQQx2ACiIwQ4ABTHYAaAgBjsAFMRgB4CCGOwAUJB489yzTnaibqWebQ762uPZxqKz5ydRV6XNjWHULS5mh3Nrez3qRuNu1FVtcyPbwHXpfLZ1cLObHc+dQXacqnTqgc9F3dLivqg7+1S2xfCJrWxTV5XOn8vO55FD2SbF+ngQdZfXLkdd1RYWlqJudSXrOs3sOWq3P4q6Kq20ss2HW7vZd+p3s9dbGO/9s+YNJ49F3bGj2e/+8SfORt3F89lMTOz9UQQArhiDHQAKYrADQEEMdgAoiMEOAAUx2AGgIAY7ABTEYAeAghjsAFCQePPc8r5sc1Av3J6z73Aze+OF+ayr0NzibtTt9PtR1+osR134cpXb2so2hg1G2XFa72WbxRbmZqKuSjvb2Qa43s6FqOsPsq1io7Cr0niQbX3sbmTX/PLyXNitRF3V6rXsHnXhYvZ7XlxcyN63sffPW4Nhds13Wtk5nckWktY6nXAuVOiaG66Jut52dn184hN/FXV/+eC5qEvs/S8IALhiDHYAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAApisANAQeqTySRbnwMAPON5YgeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFMdgBoCCtNHzof3x/1F189KNRtz7O/qbY6GUf8Xve9HjUTeNXfvx41G1cDg/n6FCULcwejbrXvvv3s/ed0nv/2cuibmt3Leo2ty9G3XxrLup+9L/9edRN4xf/6Uuibncyjrrtbi/qBrvZ6/3r3/lc1E3jJ7/7+VG3MFOPusXF2ahrtdpR90Pv/eOom9Z7Xvs3ou7M2aeirjGTfa+V1eWoe+MvfSbqpvHzd39T1DXaM1F3YSv73W9ubEbdO3/jL6JuGh99/2uibnNrK+p+9p0fjrqz21FWe/jC04ee2AGgIAY7ABTEYAeAghjsAFAQgx0ACmKwA0BBDHYAKIjBDgAFiRfUdLvZcoVaczHKFhd2oq49N8net0LHj2WLNVZWsqUi3Y1sWUN342zUVW19O1vAMdgZRd1S50DUzbbD31yFhru7UddqZX8jd8I/pdszzSys0Pxc9hnmF7PbSCO82wxHwyysWGcu+8DLq/NRd+lStnxlM1x2VKXRILuWt4f9qPvqo9lSqge+UN2isdTnP/dg1B05kZ33WiM7nwdXlrLXS97yir0SALDnDHYAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAApisANAQeLNc0+cyrrdtWxL29KhbLvU7Nwge+MKrWTL9Gr792eHs7u1HXVra1lXtUNHtqLucrZcqtYcZxvNxpO93zo4GmUbuGrjrEv/kq436mFZnU47+z33Rtm3moQL5drjvb/ma7VarXf5fNSNetl1OmplmxTXunt/3V84n13Ml8Itmo9+LXu9tYvZvaZK/a3sWj66cjTqbrn6eNSFhzLiiR0ACmKwA0BBDHYAKIjBDgAFMdgBoCAGOwAUxGAHgIIY7ABQEIMdAAoSb54btQ9G3aDz/KjbHe9GXWN4Ieqq1Jq/LupWD2Vb9/Y1shVc+7fHUVe162+7JerWLmQb5Xpb2c9uNOxEXZUOX39r1I2H2bna6e1EXaez99/9wMnro25zJ/vuvW723duTftRVbWXxWNSNGxtRNxhkv/uZhb3fuDi7kn331U52rq6rrUbdbXcsRF2Vbr79jqi75oYbou6FL8o2CT7xZDfqEp7YAaAgBjsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIPXJZLL3a44AgCvCEzsAFMRgB4CCGOwAUBCDHQAKYrADQEEMdgAoiMEOAAUx2AGgIAY7ABTk/wKbbG0VtSv7gQAAAABJRU5ErkJggg==", - "text/plain": [ - "<Figure size 640x480 with 64 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from torch import nn\n", - "\n", - "patch_size = 4\n", - "\n", - "def get_patches(x, patch_size=8):\n", - " # get the batch size\n", - " patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)\n", - "\n", - " return patches\n", - "\n", - "# run the function on the first image\n", - "patches = get_patches(img.unsqueeze(0), patch_size=patch_size)\n", - "\n", - "# show the patches using patches.shape\n", - "no = int(32 / patch_size)\n", - "\n", - "print('no: ', no)\n", - "\n", - "fig, ax = plt.subplots(no, no)\n", - "for i in range(no):\n", - " for j in range(no):\n", - " ax[i, j].imshow(patches[0, :, i, j, :].permute(1, 2, 0))\n", - " ax[i, j].axis('off')\n", - "\n", - "print('patches 10: ', patches[0, :, 0, 0, :].permute(1, 2, 0))" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "patch size: 16\n", - "patch size: 16\n", - "patches: torch.Size([4, 3, 16, 16])\n", - "embedding: torch.Size([4, 8])\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcoAAAGFCAYAAAB9krNlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaWElEQVR4nO3cy4+l+XkX8OfcqurUve89ffGMPTds47GDHctyLBIrmyCkyIgFCxb8AdmyRGzYIbFFSEiJkFgCsiIUgZSAFYQjFFsOJL5kMvZ4Zrqnu6cvVV2Xc7+w8BZ+/rZkTUrO57N+9LznfU+d37fezbezXq/XBQD8P3X/uj8AAFxkghIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgIZ+Ovi1f/mvo7nj5Sy++OaoE81dPtiMd36inkZz12Z/GO+82rsazW307sQ7+3tfzAZ3vhbvfDY4iuZm4/8R77y0zj5nt7bindPZYTQ3WdyId/7z39mNZ/l4jX79d6K52XG+c94dR3PLjcN45ztXsu6VP788jXd+Nbz8Z3q9eOejzexv/Vu9nXjns6NBNLda5Pe+PMzO98Uyf197Ps2uP53kPTr/6nf/xc+d8UYJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA1xM8/3n62iueNR1phRVXV5sRHNdTrxx6yr3azhojN/Ld55PsxaZ87OXol3rp99Ppobbb8U7xx1/iyam3c+Ee98Ujejua1R9l1WVS0W2ffZ643inVWaeS6qxz+4H809ejKJdx5tZefR8kq+c7jMzo6rw7yd5oNV9rd+Y5nvXHSW0dx6sh3v7Iyzd6Zl+Iyqqta9bOewn39HG4t5NHfSzTMj4Y0SABoEJQA0CEoAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0xD0/w61BNtg7ii/+8jSrYnplPox3Xt/OKpaG/evxzs7WB9HcePFqvHPSvRPNrRfxytroZfVOtXsQ71yfXYvmDubn8c5F9/vR3MZpXutV9fdfYJaP07f7WZXak828BrF3ZRbNbd3If0C3r2XXf2t7J9656GQ7B9NOvHMwX0dz/XxljYbZ8GrxAve+ys6Es8v5d7Q9zz7njcFJvDPhjRIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANghIAGuJmnq1x1pKyt85aOKqq3tjNWhauzMNWoKoadLPWirNB3vazXP52NDfufzXe2Z3uR3P7G4/inf1u1kp0PD3Od9bjaO7y5ofxztPFt6K52fPdeKdmnovrLzun0dzNvXzn/u2sIebdYdbgU1U1W2VNYZefbcY73wh31ij77VZVPdvIzuLRYX5uHm2sornO87N456Sf3fvZXn7v65Psc356kjeFJbxRAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASAhrjC7tKTrDpo+AKtYweVVdhd6y3inctutnPZ24p39sZfyAa382qraf8H0Vx/chTv7Ic1g8tVVilWVbU+/iCa++g4r7ZanmXVWqeTF6mw46K6vpnViV3bzGslH+xkf+uLrCmyqqpGq+zs+KDy33n1s3s/6+Z1c1urrKZz0Z/EO5+u59HcMHvsVVW13cvOhOFHo3jn6TybXYyy+0l5owSABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASAhriZ51pYjrN3mjX4VFVt3c8u3x0s453D/Vk0N5/fiHeuxv81mlt/4nK8c3b9ejS37GT3U1W1mmbtHuuzvPFm/Tz7jk6f7MQ7lz+9Gc2NluN4JxdX7zD7/T4dZC02VVXzrWfR3F4vfxeYDbMmmb9aPIx3Pqjn0dx0fTXeWYvw7JjtxSt31mED2DJv4Or2swawrVn2jKqqLi2zprBPL3vxzoQ3SgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANghIAGgQlADQISgBoEJQA0BBX2N3azKrp9vsb8cV3f7oVzXWm8cqqwTrbeZBX7U2/+iSa657k9U5X1sNobucFPufJNKvQOzg7jHeebmbdhe/dvR/vPLuefUcbj/wf98tgdi2rputtZBWMVVU720+jucU6+1urqupVJxtc78c7l+PPRHPdehDvPNkbRXOznbwWb+95tvN8Pz/jTifZ9zm5n1VaVlVdmWb1fVdXJ/HOhJMIABoEJQA0CEoAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABriZp7L38gytX8na9upqtr8D1lrxvaf7MQ7p4tsdt7Pmiiqqg5Pr0Rz61H8OGu2N4vm5pU382xvZff04caNeOePt7PrP168Ge8cfeLb0dzL4TPiYlvtfTGaG89einfuD74bzW1svhPv7I2zv/WzVf5+8e7hcTR3vM4bb6bLvWju2nwZ77y9m53bs8f5+X70XtZ0tHGaf86dxd1objx8Fu9MeKMEgAZBCQANghIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANcefa9esb0dw4rFGrquq+3ovmzh7kNW7jR9n1+9Ps2lVVo+9tRnPd7iDeOZ5nz/OwDuKds3oYzf3kIP+Onu1dj+bWk/vxzt76s9Hc/it5rRcXV/dydsycP8trx/qd/WjuYHQ53tmZHUdzDzr57+f9+dvRXLcW8c79eVbTee1SfnZ0ltk7U+devvPSSXZu3+7nVZV3etnsO1vX4p1fD2a8UQJAg6AEgAZBCQANghIAGgQlADQISgBoEJQA0CAoAaBBUAJAQ9zMc+m7N7K5xTy+eHc6juaOr53HO+eddXbteXzrtRplLSDr86xtp6pqd5bd+7wzjHf+MJw7H3wn3rm1/lQ2N7gX7xwusv/PLq23451cXL2r70Rzw928oeW8mzXEjE8uxTsX4TGzu8xbdL5a2T1tDfK2n8Vudv1l5Wfx+CedaK47zd+trm9ls7+ykZ/F13em0dx3e8t4Z8IbJQA0CEoAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0CEoAaMi7gxafiMY6J1m11M+unlVbbe49ilduX8kqo/qv7sY7u+9m3VbzH+W1SZsbvWjuyfpZvHO0OI3mPtWdxDung+xzbq324p1vrrLvs9vJa/G4uB53sxq5G4OTeGensr/Lo628VnKzss95WK/HOze670dz03o73lnD7Ng+f5jVeVZVzZ5ktXg7l/PIeG2SPftbO6t45wevZnNPp1klX8obJQA0CEoAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0xDUL4739aK5T4/zqd7MmjPPjfOXsU0+iucV+3k5ztp89ppOjrBWoqur2UdZGsR7M450vd59Hc6/Od+Kdo1X2nG4v8+99o5s9p6N13nTExbWa3I3mzuZP4537dTObO89ataqqxvO/Fc09CZuGqqp2R9lsZ/Q43tm5m/3ONxYv0Ep0NWvx2biVt+i8Mslmx2/kZ9wfr7PGnf/zPG8lSnijBIAGQQkADYISABoEJQA0CEoAaBCUANAgKAGgQVACQIOgBIAGQQkADXGF3fJvZ6Pr7lZ88fXG5Whu+FH8MWu3P4jmPnx8Eu9893421x8M450bh1k922Sc17i9vs4qq36zkz2jqqofr3rR3F53Gu+82skq7D5a/mJrqPjr0el8IZpbL/N6tOU6+x9/sJVXw9VednYdPz6OV47vPYjmDgb5b7L/KKsT7T7KayVH17Mz9tFv5O9WR/ez7/P0dlZLV1X1h1ey6z96rMIOAD42ghIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANceXN4Y2zaG6RF8nU2fluNLd+KW+teD7biebe6xzEO8/uZPc+3MvmqqoenGazNz6axztvz/eiucNJ3p40qEU2uJO3J92ZZQ1CWy/QVsLFtX14M5vL/4Sqe3YezS3m2/HOjdPs8Np/PztjqqqenWWzp1uH8c79d59Gc6OtR/HOvwrbcX50K2vVqqq6cZq9h93Y2ox31ueztqGrD7MzJuWNEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANMSlUaeH97KFs2l88cFhVrlWJ3nlWn+YVdONXroc77x0nNXNHQ6O453j8VE0d/1/ZnNVVbfv9aK5v+it4p1vV1br9dWsWaqqqo7nYbVVN6vV4mLbOH4zmhu/wNmx7qyjuUHvo3jnopfVsy3P8rNjuZf9fo533op3zo6Oo7lnt/4g3vnTtz6I5o7v5dVws3eyuZsvZd9lVdWnj7O5k1fz6s+EN0oAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0CEoAaIibeXo7WcPEcidvRFiPsiaO7jDP8+Usa/E5enIY7zw5vBTNra/kjR0vjY+juV/tZnNVVXeWWT3Of+pk32VV1c3w6+x18raf+51Jdu1F3gLCxXWlcyuaOx3nDS3jehbNDSr/u9zrZL+f1fw03jk/uRbNbU5vxju3DsKmsK/vxzs/9Wb2W/vc7+Vn8ZthFrwyW8Q7v/yT7G/k3qfynQlvlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoKGzXq/z3igA+BvGGyUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANPTTwX/67DfipU8Wz6K57ngV7+xO19Hc2Tvb8c7//Rc70dyVr+zGO7e/MIjmvvRgGO/8h38Qf0310fPjaO7fd7PnWVX12dlBNHdzOo93vvx4Gs29cTqJd979j/8lnuXj841vfBDNLd8/j3eedU+iuf0axzuXqx9Fc7N734l3Dq5+OZrrLf5uvHP3je9Hczv/+N/FO7fG2Zk9+7f5u1X3TvYb79/Kv6PZW8fRXOdLeQ78m8//8OfOeKMEgAZBCQANghIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgIa48mXxIGuDqKqafvgkmtt5FK+s/nwrmuvu7cU7b169Hs0NP8xaK6qqXn4pa7z52mgj3vnmeBbP/sl51oZxd523In2204nmrm3mbT/bN7PrTzbzViIuph8v34vmzldZg09V1Wz3XjR3ZTCKd9YsO2Oml7L2raqqre73skvfvR3vXHz9LJrbeSs7C6qqBr93lM1dylt0up/LzoP5q/nznB9ks5tn2ZmV8kYJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAa4n6w3n/LK4FO72dzk2VWGVVVtb11EM11P3k53nnljZ1o7kEvv/e/05tEc5/byOulavICz/69YTS3t8jr5t7oZp+139uOd37vSnb98a3sfqqqXo8n+Tjt9r8ZzW2O3o93TjpZrWNnmFcgbq6zGrnh9mv5zuPsjDn/rcN45+q35tFc51J2Fv3sA2SH9vbBMl45+kI2N7/ci3fWn2bP8/n5lXznr/38EW+UANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0CEoAaBCUANAQ11aM/9ciXrp5I1u7tXcp3tm9l7U3LLt5G8Xurazd4x99Lv9/4u8ts+d09TheWT+Z7Maz3fFmNHe6Po13/nSczZ2u82aeb56dR3PDw2yuquqfxJN8nI4nfxTNdfpZ+1ZV1U4va2gZLz6Id/YX2W93s/dKvHP+21+M5jq/mf1uq6rmN+5Fcxs/zBu9FuG9Hz3Pm47Wf7nKBpd5S9jqm9n5Ph3EK6v+2c8f8UYJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAa4j6i9UFeXbS6lNUxLR4N453redZJ1Dldxjs3383mvlhZfV5V1eAs+9/jz/rzeOeHm/mzny6y2qijYdhLV1Xv7GX3tO4fxzuXo+x76q9G8U4upqNZVje3v8wr7Jarh9Hcap7/ztbrs2hua//9eOfki9+P5vYHeaXk+ntZ5dvsx9fjnYvJfnbtZ1vxzskPsvN9Pszr+xaTk2iu28srV6N9v9BtAPBLRlACQIOgBIAGQQkADYISABoEJQA0CEoAaBCUANAgKAGgIa98WeeNDIvHWZPNYJY33iy2sjaKOsvbfg6+l7XD/P6POvHOG4tZNPfSILyfqhp183aRwZ1sdmcn+5xVVd3Njezaq/N45/Ww7ej0s1nTEBfXapX9zuazJ/HO8+1s52qZnwedRdbM8+TSd+Kdo//8djT32vfzpqw7d7PfzvPn+fn6dDNrTzrfz8+NSW8SzY0HeYvO8yvZ99k/+mS8M+GNEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANMQVdqtFVnFUVdXvZJVE/a2sGq2qqrvKqtl6i7wabtbJ6tGezEbxzserrO5uPjqId9ZmXh946VZWA3ZruB/vXIbP/sPzvIKs9rN6re71fCUXVHgejHuDeOXxajua6+Xtk7Xeyc6DvXl+xnzp4fNobvkw/6Drk6NobtDLa/H6V6bRXPfacbxzMM3q7vYe5u9rN9afiuYm/SvxzoQ3SgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANghIAGgQlADQISgBoiJt5Op1hvHTz7Dgb3IsvX8N+1sSx18laZKqq5pOsMWTvwaV4Zx1nrR3PO/n/KN03JvHs4HA3mruxzpo4qqpmT7J2kbfC51lV9d9vZ3PrUdbgw8X1hc3sd747zBuwfrLImqUebudnzOxq1syznuS/x8cbWYvOdJH/dgaPzqO5zkZ+xgxW2fne+2TepjYIG80Wd/JGplt/lDXunI6W8c6EN0oAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0CEoAaBCUANAQ9zv1+1l1UFXV9CyreNrq5fVSvYOsmm5eL1AFdT+rR9t4L69DGmxmc9s3ZvHO/ctZXVdV1eNJVqF3+9mH8c67R9n9f3Q9r7f61cPsf7THP3iBKqp/kI/y8TntbkVz0+lxvPPmzjia627mv52PxlmF3Wyd/3beHmb1j6udnXjn/kFW53ljJ7/3/qPszN6+k3/O9WvZ+T5/cjneefJrd6K5nQ8fxjsT3igBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoCGuxrnWuxYvfTocRnPLs0G883x8FM11F2fxzv4wa33Z/5WssaOqarCXPdLzXtZWUlU1mOezsw/ejua+fXoe73xzO6sbuvdG3rTUPczaRTYf5E1LXEyPutnfT/9p3vpyc5U182xVNldVtb/M2rIm/bB+q6rGV96I5taDvEVncZL9Jmb9Ubyz+lkzT+fb+VnUe7SXXfrhZ+Od086NaG49eDfemfBGCQANghIAGgQlADQISgBoEJQA0CAoAaBBUAJAg6AEgAZBCQANghIAGuLOsbsf5XVznWvZ7Ad5M1w9Ps52LnfyGqzdg+z/hFEv/39iNQtvqn8Q73w8PY1nz86yqr/5Kq+iWt/N5nZvZrV0VVUP792O5s4fZ9VaXFzXNp5Gc0dbV+Kdz0fzaG5nnf0eqqpe3cyq6T7Rz6rZqqr+9Hl2dtw/y+o0q6qm3ewsPHuan1v9RbZzI7yfqqrB85vR3OpS/r13nj2K5rqTF6jvS/b9QrcBwC8ZQQkADYISABoEJQA0CEoAaBCUANAgKAGgQVACQIOgBICGuJnn0jxrwqiqenwyjuau7y/inTv7J9Hco3Ge/bOnWTvNxm7ejDPrZrPzRfzoa9nPrz/e6kRzw428aWn0WrZzMvp8vHP+x1njzvzt/O+Oi+nKYB3N9UZ5m8pLvWE0d76dN2C9Ps2uvz3L234+vZ81hZ2+wDvL82X2PHem2TOqqtoMb2m52Yt31uSVaGzxJLufn32AH0Rj6+VH+c6AN0oAaBCUANAgKAGgQVACQIOgBIAGQQkADYISABoEJQA0CEoAaBCUANAQ96htrZfx0v1ZVk23O8mri8b9rI5puLEd7zw6yqr2lrPH8c6tw+yeBt0P453Tfn797cqq6fpPXo939s6eRXPr82vxzvm3bmY7R4/inVxMr4enzGAjqzWsquqF/+Lfn+ZVjcerzWjuQT+vsDvrZhWM29f3452rVXbz6+ONeGe3n53ZnY28TnO5He5cPY93dmc/ygaz1s38ur/YdQDwy0VQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoCFu5jkb7sZLe7tZI8NOP2/iGC6zZqDNYd7ycHCYtb6cPDuLd56ss/89RpdP4p2TRd6wsfHTl6O5wZ9/Mt45vbyK5vp38haUbvfL0dzmwf14JxfT5UEvmpvsxsdRTcKj49o6O4uqqlZb2fV3V3n7161x1mRz81H2G6uqeri8Es09Xc7inYMrT6O56eCDeOd6mrUNbXXz53lW2XMa9PbinQlvlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGgQlADQICgBoCHujHpvkFUcVVUdL7PapmuLPKeHYcPT7iiry6qqujzO6u7O+zvxzuOzrWju6TKvsHt6klVWVVWt3vn1aG79NK+3Wp5ns6uDy/HOejnb2R1/Pt/JhfTV8Jh5vJWfB6uwme7pfB7vnNezaO5yfxTv3BxltY6XRvnOB2Gj5e/XON75uJed74P1ebyzE9aO1vJmvLNbt6O5eSdeGV4XAPj/EpQA0CAoAaBBUAJAg6AEgAZBCQANghIAGgQlADQISgBoiJt5BnfzpRuTs2hudZbXJyzWWRXHQTdv5rkW/pvQna/jnaOdrHHm2SBv+3my8fV49ry+Es0tPv+78c71W7vZztfyJpDxV34czW18+Jl4JxfTW1dfjeZOvxTWb1XVR9NJNHfpSd5ANersZXNbeavWJxfZEfvKKj9j9oeb2WAvv/dJ/zCa62zk59asm7VqLVevxTuXk6zBqL+RZVDKGyUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA2CEgAaBCUANAhKAGjorNfrvDsJAP6G8UYJAA2CEgAaBCUANAhKAGgQlADQICgBoEFQAkCDoASABkEJAA3/FxQz9KMfad4nAAAAAElFTkSuQmCC", - "text/plain": [ - "<Figure size 640x480 with 4 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAD5CAYAAAC6TTYBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFSElEQVR4nO3YMavWZRzH4fuoU4lgQnWmg6RH3EWiV5C1ODQ5CgotLeHgZI7nCNooKEJE4AsQx8aMbBGcjsIZBUfBNBT+vQAhHfR+hs91rc8fvr/p4cO9tizLMgCArD2rPgAAWC0xAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABA3L53/XDj1taHvOMNHx18MXXv0P5/pu6NMcbTP9en7m1c+mPq3s71k1P3Pv70+dS9McZ4+OVvU/f2fP5o6t77cOyna1P3Xq6/nrp399TPU/fGGOOb33+Yurd59u+pezu3TkzdW4Xdr29O3Xvbf4eXAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAECcGACBubVmWZdVHAACr42UAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgTgwAQJwYAIA4MQAAcWIAAOLEAADEiQEAiBMDABAnBgAgbt+7frhxc/tD3vGGvc/2Tt17fOb61L0xxjh859zUvc3z96fufXbvwNS93SvHp+6NMcaFrV+n7p3+4sHUvffhyPbVqXuvDr2eurf77Y2pe2OMsfnL91P3Dl+8N3XvyY9fTd3795Nl6t4YY1z+7vbUvTNH//rf370MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFiAADixAAAxIkBAIgTAwAQJwYAIE4MAECcGACAODEAAHFry7Isqz4CAFgdLwMAECcGACBODABAnBgAgDgxAABxYgAA4sQAAMSJAQCIEwMAEPcfiCw/6Avfi0AAAAAASUVORK5CYII=", - "text/plain": [ - "<Figure size 640x480 with 4 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ + "import math\n", "from torch import nn\n", "\n", "# define embedding class\n", "class Embedding(nn.Module):\n", - " def __init__(self, patch_size, in_channels, out_channels, return_patches=False):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", " super(Embedding, self).__init__()\n", " self.patch_size = patch_size\n", " self.in_channels = in_channels\n", " self.out_channels = out_channels\n", " self.return_patches = return_patches\n", + " self.classify = extra_token\n", " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", " self.norm = nn.LayerNorm(out_channels)\n", "\n", @@ -186,19 +69,39 @@ " patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)\n", "\n", " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", + "\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", + "\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + "\n", + " return pos_encoding\n", " \n", " def forward(self, x):\n", " # get the patches\n", - " print('patch size: ', self.patch_size)\n", " patches = self.get_patches(x, patch_size=self.patch_size)\n", " # flatten the patches\n", " patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)\n", + " # add extra embedding token if classification is needed\n", + " if self.classify:\n", + " extra_token = torch.rand(1, self.in_channels, self.patch_size, self.patch_size)\n", + " patches = torch.cat((extra_token, patches), dim=0)\n", + " \n", " # get the embedding\n", " embedding = self.conv(patches)\n", " # flatten the embedding\n", " embedding = embedding.reshape(-1, self.out_channels)\n", " # normalize the embedding\n", " embedding = self.norm(embedding)\n", + " # add the positional encoding\n", + " pos_encoding = self.get_pos_encoding(self.out_channels, embedding.shape[0])\n", + " embedding = embedding + pos_encoding\n", " \n", " if self.return_patches:\n", " return embedding, patches\n", @@ -208,7 +111,7 @@ "\n", "# use the embedding class\n", "patch_size = 16\n", - "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True)\n", + "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True, extra_token=True)\n", "\n", "embedding(img.unsqueeze(0))\n", "\n", @@ -227,157 +130,474 @@ " ax[i, j].axis('off')\n", "\n", "# plot the embeddings\n", + "# plot the embeddings\n", "no = int(32 / patch_size)\n", "fig, ax = plt.subplots(no, no)\n", "for i in range(no):\n", " for j in range(no):\n", - " ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))\n", + " ax[i, j].imshow(embedding.squeeze()[i * no + j, :].detach().numpy().reshape(1, -1))\n", " ax[i, j].axis('off')\n", "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class SelfAttention(nn.Module):\n", + " def __init__(self, embed_dim):\n", + " super().__init__()\n", + "\n", + " # Query, Key, Value weight matrices\n", + " self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)\n", + "\n", + " # Final output weight matrix\n", + " self.output_linear = nn.Linear(embed_dim, embed_dim)\n", + " \n", + " def forward(self, x):\n", + " batch_size, seq_len, embed_dim = x.size()\n", + "\n", + " # Create queries, keys, and values\n", + " qkv = self.qkv_linear(x)\n", + " q, k, v = torch.split(qkv, embed_dim, dim=-1)\n", "\n", - " " + " # Compute attention scores\n", + " scores = torch.matmul(q, k.transpose(-2, -1)) / (embed_dim ** 0.5)\n", + " attn = torch.softmax(scores, dim=-1)\n", + "\n", + " # Apply attention to values\n", + " weighted_values = torch.matmul(attn, v)\n", + "\n", + " # Apply final output weight matrix\n", + " output = self.output_linear(weighted_values)\n", + "\n", + " return output\n", + "\n", + "# get the context\n", + "context = SelfAttention(embed_dim=8)\n", + "\n", + "# get the embedding\n", + "embedding = Embedding(patch_size=16, in_channels=3, out_channels=8, return_patches=False, extra_token=True)\n", + "\n", + "# get the embedding\n", + "embedding = embedding(img.unsqueeze(0))\n", + "\n", + "# get the context\n", + "context = context(embedding)\n", + "\n", + "print('embedding: ', embedding.shape)\n", + "print('context: ', context.shape)\n", + "\n", + "\n", + "\n", + "\n", + "\n" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,\n", - " 0.0000e+00, 1.0000e+00],\n", - " [ 8.4147e-01, 5.4030e-01, 6.8156e-01, ..., 1.0000e+00,\n", - " 1.3335e-04, 1.0000e+00],\n", - " [ 9.0930e-01, -4.1615e-01, 9.9748e-01, ..., 1.0000e+00,\n", - " 2.6670e-04, 1.0000e+00],\n", - " ...,\n", - " [ 3.7961e-01, -9.2515e-01, -4.6453e-01, ..., 9.9985e-01,\n", - " 1.2935e-02, 9.9992e-01],\n", - " [-5.7338e-01, -8.1929e-01, -9.4349e-01, ..., 9.9985e-01,\n", - " 1.3068e-02, 9.9991e-01],\n", - " [-9.9921e-01, 3.9821e-02, -9.1628e-01, ..., 9.9985e-01,\n", - " 1.3201e-02, 9.9991e-01]])\n" - ] - }, - { - "data": { - "text/plain": [ - "(-0.5, 63.5, 99.5, -0.5)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAGFCAYAAAASDy0NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA7SElEQVR4nO2de5hddXnv11r7Ontue26ZXCbJJJlMkgbIjVtAQIECIqgVH9HKsVYQqQfQPmI99Txa6+lpCxREpda22nppjZeqFUGtAoJyFRLuJCSQ6ySZZGaSPbc9+7bWOn/Ys36/77udtdcwGZjk9/389Xufd1/WnjzPm997t33f9y1CiJE4r/cDEEJeP2gACDEYGgBCDIYGgBCDoQEgxGBoAAgxGBoAQgyGBoAQg4lHfeGSz98G8itXfjk4L/vOdZPqaumn8t6Z+lw+0+x6pum8l8+kcObusGrBGwAhBkMDQIjB0AAQYjCRYwDfffsXQL5+/znB+eMX/xh03x5tAfniNzwN8tPFYnBevmYf6PoqYyBne44E52FvAnTJrnGQi35ZCR1F0JV9F2QvWwbZ9T2lq8fXSrw6b3JdanKdnwjvu/Ljk+v9WI33hply+1XqyAkPbwCEGAwNACEGE9kF6IyVQH7yjnXB+c5bHwfdkruuBfnBS28H+fpd7wzOH+z6Neh+OLoa5Iu7tgXnZ0p1oFs99yDIfRV17Z/bPgy6Ya8Acn0zykW/EpzjDegeSPfBrlOv1V0Hy7IsKxniAoRc8WvqY6FvtSwnxH0I09VyAabjPtC9mPXwBkCIwdAAEGIwNACEGEzkGMC5938E5OXfUn7/z/4yBbqef0cfetFbG0De/sDS4Hzx1f+J37PlzSB/8aRNwfmu4fWgO6tlJ8jPleYG55XZw6A75KKt62waBXnYUzGOhnoZH8Dfk0irGEDFwviAk0IZYgQiPlAVPwhJ9dVKA4aa8tAUYY24RIgfXzN+EAbjA7MC3gAIMRgaAEIMhgaAEIOJHANYeQv6zJWNpwTnDz+Evvnyh7aA/LM8xggW3qtKehs+mAZd/qk2kNesV775DftWgE6PD1gWxghWNxwA3Y5yB8jdDUdAPuKpRHtrfR50o14F5ExaPZOsEYgnUdZjBE4ivMTYEnUAECMIyeVbVo0YQZi/Xeu/gBoxgjAYI5j98AZAiMHQABBiMJFdAO+VPSAf+o5K5XV/KYEfuqgL5E+/1ANy2xNbg/OzJUy5zdmC123dRRjZgV2Gy9dhd+BvBhcH508s+RnotuS7QV6SGQR5f6UpOHfWobszKu6yjWlVclwQLkAqNXkZcSyOaT/Pwuu1HZ+8jDgsRfjbN7+6UuBplfPSPTju4Q2AEIOhASDEYGgACDGYyDGA/ms3gPzwqarF94p7zwbd3o+dAbJ9P36WX9kVnL95ZCPoGp/pB1mfAtS8HZ2/Ngfbg/fsbw/OS3sxzfeN0bNAvqTtOZD3lVX6cV4aW4mPuJiqzKbVM+XFdvVMUsYAlF+fSFaETpQRh8QI7FhIitCyapT7Tq7ya/nx02oHfh02zzM+MCV4AyDEYGgACDEYGgBCDCZyDOD91/0E5HsnlL8d78Qy29OveBbkA9ctAtlfo0p6f7gNawiW7UXf/KlifXDO7sCxZDEb7VeiLxmcO2Oo23EUn/HauRgjuFcbRTY3hTGAw24jyG0pNY141MNZXfVJfMay5scn42IysYV+fCwmYwBKtmMhNQKWFV4qHOaLz+RIsGkwrToBEhneAAgxGBoAQgyGBoAQg4kcA7ghi+O3Tv7y9cE5+U587Xe6cAz4Fc9incABrU4gsxnfazvo/N0zvCY4p3cOgE5uCmrQlgw12NiCPDSIfvz8ONb77xxXMQ1ZIzCg9QlYlmW1JvQYQBJ09QmMARS0OoFUQtYBoF8fj8txYpPXAVT1EWj6Y1UjYFkR6gTCmKE+gmnB2ALAGwAhBkMDQIjBRHYB3vLS5SB3f/GF4Jy+C6/BAy5edZ00Xscbzj8UnOtvacbXdi8E+d596r2dB3eBblcZ7Vdjn/pemSJ0BvAZW4Xp6xvLBue5nZgG/E1+GcgtCTUxKOdlQNeUwPbmcU99UTouXABxjQ9LAzrONNKArzZFaFmzcrEoU4THDt4ACDEYGgBCDIYGgBCDiRwDKPzdfJDrtam7X13yPdBd9sL7QK47PQvyX/V+PTjf9vz5oBs5bynIozuUwzenWATd5sJikOv2q/ScbLWtGxBjvRwRtxhR24vaYuOg6y9inOKk+r7gnHPDYwBFX5UKp+PYKlwSrcSJmEgDajGCWFUaEGMCtubnyxShjAFAmnAapcAnXCuxgfAGQIjB0AAQYjA0AIQYTOQYQOonT4C87XNnBmeZzy5v6gR58AL8rAvqlK97ywCW9x5ejyPEm7ers5NBf/uRYXyt0z8UnI966IvXHcZnTNnYhlzIqbFfrQ6W8/YXsIz43Oax4CzLhBvjog7AV9+TiYtWYeHmVrUL+3oMQG4WFqXAIbl+O3QzEH3tAAPrC3gDIMRgaAAIMZjILkDxLaeB/P23fz44v2/Hu0DX/qMXQe4SpcK7yuoKLa/1besOg1z/C5WCs+eja7HlEJYYdx5VpcIHKvjTMgNYhiuJ5dTrG0VH4uE8ugBZR5UCb6/MA11jDF2AUU9NLs6INGDBR/ubEmlA3bWSpcAyDRhaKjytzUAzNE1ohmCZ8NTgDYAQg6EBIMRgaAAIMZjIMYDMTftBbo8pf/bwd3Hqb2fhKZA/s+i/QP7r/ouV0IvlvNd04xqhH+xUsYf8SowB5Paj/dJLhbeX54AuNYC+uZyak8wp5zEjUoRH87iBqMlRnzVYbgDdsjSmNcc9Faeoi4mtQZaMAWCcQk/1xYWP74rUqwOlwJOXCVchfOaqaULTKOcNLRWmrz4r4A2AEIOhASDEYGgACDGYyDGAH/feDXLPzz8SnFd9/2XQjbxlDcinJB8D+Ze/Pjk4t6xDZ/DyhldA/l6/ihHkLsdxYZl9FuKo1tsXJrpAFTsyBvKEj2W5qZx2tvHPkh/DeoNmR8UajpSwjmFtPU4qHvVUiXF9HNuZCz5uFZLtwHrEIF5jM5AzhVJgaBeeTkvvTMKJwq8JvAEQYjA0AIQYTGQX4PNHsfNu1S0jwdnL4RTdo3+I1+3fFDH91XWfuurufYtI5cXqQfa11N5wL16RO57A98Ya1HtfGGnBH5AbQdHDlFsqp67YcqKwN45pwXotJZcTLkCjg+nGA2X1HBlHLg3Bz03LNKF2061yD3xZCqwvBhGdgiFX5tAUoWVN77o9CycKE4Q3AEIMhgaAEIOhASDEYCLHAL755UtA7typtnoWLsK03zfW/z3In9r9dpAzj6oxP2v/fPJWYcvCduHO5YOgq/8hTuu15rQFx1eOiFbhUdwqNOgK/zuHPrZObBTTdRktr5YrYJlwo4NpwGFXTVPOxDAGkPfwGdOiFLiktQsnqtqBEb0duLoUOCSFOJ2W3uNwci/bhRHeAAgxGBoAQgyGBoAQg4kcA+j8xydB7r/u1OA8dib6vack0Wfe9ctukBcNPxqcr52P48M2DW8A2V4wNzhfMn8r6B7fvwrk8rxscM4N4jPIrUL7XYwfJHMqBy9bYhNj6DimbfXZowX04+sd/J7hiooRLEgdBZ2sA0g5cnuwFgOQW4NErl+2C+tUlwJ7mm5qpcDHaqtQTeirvybwBkCIwdAAEGIwNACEGEzkGICzvBvk91/3k+C8Nr0XdH85sBbkhfflQY71LgvO56YfB91Nz54KcscyVQfw5qZnQPf4AI4IGzu1IzgnBsKdyN2lDpBjw6qGv2Khv53A0gQrocUA8gWsY6i3sZ5/pKLagZfXYR3AuKgDkCPBylodQMyWI8GQmN4LEDIurIqwVmHLmrFcf83NwmGcYLUJrye8ARBiMDQAhBhMZBdg2ydw+u2PszuDs2yfve7fzgF5yeYtIPf/8brgLKfveJsxPXe0V51PSuD1zh3GFt+xLvUcaVwwZNlx/J7dhXaQndHx4Fzw8SqeGBMpN0u5AOUJTOVlRCpvtKxcgLQt2oE9fG+dI5eHqu9JijSgTPrF9FJgX7oAky8WrdUOHLpYlKk6xXH6t+ANgBCDoQEgxGBoAAgxmMgxgF+96Qsgv+WlK4PzRXOwRHfpt3E7jueiD1q4YDQ4P11Cn3nOFkyj6SPDMg6m3CwP/eLxLiXLcWFyC/Gucfwsf1ylKkc9mQZEP1mPefgFLDlOizTUWEWl+uqFj3+gjOPPZClwSYs1JGWZsHDd9RiATAOG+fE1XdfQ7T4zlMqLoifHBN4ACDEYGgBCDIYGgBCDiRwDOOSiz1y4VY26+oc34nbgpVsfBdk+7WSQP32y2jL01cFzQVf/3AGQ531E+dCD7jjo7AQ+U6ZL1exmfoI+v93cBHLfaBrklrz63mEP/frk+OSttk4ebWhaONyjJT0GgK3CeQ+fX8YA9DqBpCPagYWTHAvxx2UdgM5U24Ej6yzruCvLNXFcGG8AhBgMDQAhBhPZBbjyBzeCvOynauHnsv7V+OJ1KPdd0AjyFQ1quu8nH30H6Jbvx7LhN89X04a2FLOgc7JYNnxKp7rGHxmYDzqvDV2AI8NYhttcUNfzIQ8n/SZGJ58YHJ/Ae2NC3IvHS+qanxadgqMuuiHNceyaLGtpwLgjNwOJrUgwFRiR7gFMDQ7TWRFchNcDA6/qMwVvAIQYDA0AIQZDA0CIwUSOAay4Haf+jL/ltOCcuucJ0L1y60aQ557cD/Kwp6bvNG3BqTh2HH3zNzeqiUFfHzobdH5nG8hntzwcnH8yhHGHiR6cAFQexlSfXlY8UMF4QXwcfXd9Mm5MxABke3OhrGQZA5gQ24k6EzINqOIHqRppQH0q8FS2A4dOC6pF2MTg36GPrCOvGbwBEGIwNACEGAwNACEGEzkG4Odx+0/Dx/uCs7erF3TXX/pTkOXU4DuGzgzOczZj7ttZimXFqzQ3+f69+D0dXVjue1qdGlN2z3AP6CY6sC4gkZvcCe2vYH2BM4ZtvPrU4Dj+WayYKAUuFtUPSNvo44+5YiqwgzECfXNQolYdgK23A4tnEn6+3i4cOvLLssKnBs9gjcCrnhp8gpUnzzS8ARBiMDQAhBhMZBdg1424iPP55XcG59Xv/p+g+2jL7tDP+uNfqDThqhd3gC538QqQ9SlAxe2Ynhvtws9dHldXaG8MOwfzc9DWJXPioRyVFjxYyqIqXwC57KtLdky4APrEYMuyrEpJySkbL+cyDZi05dSfyScCyTSgPjVYlgLLNKA+NbhWqW9NFyH0zdN474nELP478AZAiMHQABBiMDQAhBhM5BjAzVd9DeTbjy4Pzle+7Vege2AC7crCOG7wmX+fcorco0dBN7AeHSZ9ClB2Oz7T2CJ8bUtMpQX9stjC04G+bn0fiJaTVP74wQK2A1sTGAPI+3oacPKJwZZlWV5B/YmTYsFnvoITgWSpcM5Vv0dOC9LjA5ZlWXEtvuCGTAyWyBiALCMOS5vVTiEy5Tbb4Q2AEIOhASDEYGgACDGYyDGA8+uOgPwXd7w/OD/5yTtB1/vLq0E+t+dlkLO/3h2cvUZs2+1ah1OBnyy2qvftQF988GzR0htCqQN96Lbn8b12vfK3DxWwRNcv4PcWND85MVEjj15SNjYlfOaqGIAsBa5opcC2LAXGf7r4FEaC6aXAtdqBQ938mfTx2Ur8msAbACEGQwNAiMHQABBiMJFjABuf+ADIC/7pyeD89E3oXy/8Fn7sQ288CeSl/WpzkH0q6q5ZhK3Edx9dG5yTuw6Dbv4C9NX1mgG5Nai+A9uOUzmxOUjbHjyYF1uDisMgj2qbg+KFyXPslmVZdlE5rEmROJ+oYC+ArAMoepPHAEpVdQB6OzB+jxOay6/VCxC2HTj0rTX8+OOvRuBE3BzEGwAhBkMDQIjBRHYB5t2C11Wnd0lwvupJbOFd/POnQO5y1+KXLl4YnPvXYxrw0sw+kP/6+UuC88JD2Dp8Wgdev3dqaTWnoR50S9uGQC4M40/3m9Trh8fRfciWsKx4VJvWG5sIdwFiJXVvlFuDChV8hoRoB9aXh2YcfIayaDvWNwe54q4qS4F1Sd5q5WagWpuDyPENbwCEGAwNACEGQwNAiMFEjgHYjz0L8ravrQvOCzaJ+EBbK8h1j2If7+Db1fbgI+sxvaW39FqWZZV3qDFgssX3rMZdID8xsVQ9bxbHh61q2gPyc8Oor2RVC3BxXMQHXHzGnKeeMZ6ffHOwZVmWo6UBE6JVuFjG75FpwIKWBmwWs8dkKXDC1kuBw9OAestvWIrQssJbfmtuDtbeO6WtQVH0x/p9hsIbACEGQwNAiMHQABBiMJFjALmrzgT54TfdGpw/8KELQdf/x+tA7vjyYyAPXFAMzmcv3wm6rSUs2W1+SZ0d0Tq8NoWtw58ZWB+c3dYG0K2uwxlgz48sALnUrW0DGhN2UYzJ0kd1xQpiXLfwdWPqp1qOsLelqjoAjCcUvfikuqpSYL0OQG4OtifP3TuiRkBvFbasaZYCzxCvemsQqYI3AEIMhgaAEIOJ7AKc8ZEnQd6nLba009iVN/ddmHKzHlwO4o2n3h+c5eLQf8+dAXJ2h0p/2fM7QdcVw/Tjs4fUAtCODkwnrkwdBNkfR1ej2Kx1+I2F322HXOVe2BOYuquI1Zy6CyAXh5bLeI1PijJbfXOQXBwq04D61B+5OLRq8q92dmpc46eyGcizok8Uft04wToUpwtvAIQYDA0AIQZDA0CIwUSOAXxh/hMg92y6ITi3vh0dq7t7/g7ki8//OMjXNG8Lzgkb/eAPPYkpxB5tCtDYOlwHrG8OtizLyvcp3zzfIbYT6c64ZVn+BJbWFrPqN8RHhaMoHOGjFdU6bBdFm66PMQC9i1emASsiBpAQ6bqSp28VmsJmIJkGDGsHnsZEIG4OPka8jn8L3gAIMRgaAEIMhgaAEIOJHAP4UN9GkFfcrvL3w1/FKbrSqlQuyAm98zvPlmVZiS1YwuseejE4H+3tBp30t+v7lF880YHP0BrDWgWvJCbwajGAxDioLDuO9QaDmu9uF2QMQJYCa623wtnzyuK325PXAchS4HEPf4+u90QdQGgpsBz55YeXAsP24Gm0EpPZAW8AhBgMDQAhBhPZBXj2jjUgt+RVm943V30XdNfsvALkT62+B+R/HVkWnNemsWx4zmaRrquo9NdoL6bCDrqYymvYp666Q2vw/pmy8RpveaK7Lquus3UD+FI7ie89UlIugF+S03rxWoylwMLeShdAvLfkqu+RE4PLPpY666m+kpgY7AgXwNVv8dOZCBT6zhoYWHY7G+ENgBCDoQEgxGBoAAgxmMgxgKZNj4O869MqLdgRw4Wee763DOQr/hz1vb98R3A+t+dl0NU9j5N7PG0KUPeyQ6B7sdQGckOfcrj7L0Q/uBaVrPKxk2P4XtnufKSoyaUC6AoijRYvhpTSVqUBUV/Q0oBJkQaUpcChacCQUuBYVavwFCYC1WKm/PzQlt6Z+coTFd4ACDEYGgBCDIYGgBCDiRwD8M/COoCbr/pacL6h7/dBN/8/XgF56BOYr2+7V5UOPzCxEnS9/dh2HFulxoldOm8L6B4d7wE5cTAXnLPt6LePeeir23H86cmsih8kxrC02U6jnCsoubE0ArpxD21qrBQWAxBbfIW+6OpTgbEOoOhhbUIipB24ajOQpq+9GSisHXgGc/msE3hN4A2AEIOhASDEYGgACDGYyDGAgT9DH/r8uiPB+X//x8mgWzD4G5BvGzwb5I4H9gfnSh1u6HGEvz3e0xKcL254AXQff+WdIMcODwXnZa04LuyQiz60nUR9e/NYcE6Mie3AdRhPGC0ouaEifHORn48VQ1pxS+irJ+TYcL0XwIpeByBHhst2YNefPAYw+dP+f732iilsB56S7jjEP05/D28AhBgMDQAhBhPZBXh4wzdAPuOJDwTnRd/bB7r8BWtB/u6jaGd6dysXoWNzM+isnm4Qj/aqR1yRwGvvywdx7E/PqHItVjdhm+7uCn6PnakDeWFjLjjnxrDV1q9Ht2SioKXgXDGpx8f0nBPiAojMnhUT92K9HbhWKXBGGz9cnQYMeYZaacBQLeLVdCDIbIM3AEIMhgaAEIOhASDEYCLHAH6abwd5/s3qrZW92MK757Pomy+4S6TGOjT9NhwJNvS21SCPLleOshzrFetD39zSWnFX1+EzbSvOB9luqAd5cUZNOR4Rv7WSxXhBpaB+uy9iAKMevjY0DViWaUC0x6XK5CPBCqIUuDmmyq1lGlBOG/ZCSoFd0c4cVipca+pvaHxBvNcV05RnLIV4nKbrZgreAAgxGBoAQgyGBoAQg4kcA/jU168CeeHjjwZn79y1oPvy2d8E+Y4/exPIo+csDc6ZH+CoscEN6Dcu7ekPzgcrY6Cr34cOnZ1SJbq9icOge2B4FcheE+b6l6TULPDn84tAV+lqAtmf0Oym8JlHPLElqay16Qo/18FShaotSXopcEyM6qp4cpyY/j2oC2sHDtsaZFm/YzOQ9hw124Hpb896eAMgxGBoAAgxmMguQPcXnwd56Kozg/PhszEVdlEGF2/edgiv430XdgfnVfdjie6KNXtBvrBjW3B+vDgXdI19mBqLtWSD8/w46rbmOkGON2O6rjs5GJz9Cex8LDeINGZ+8rvtqIuf6xTUc3jiGu/IUmCRV6tAGhD/xnIikF7uW9Up6OB7Pa11rfqKjxyrqT/yt88U/rSmGNfSn3hTingDIMRgaAAIMRgaAEIMJnIMwG7GVNjGj6rpvWc17gDd3+cW4pd0Y1rtXRtVO/BTq3Da8Ie7vg/y0oTyzT/ffyHoMn3jIHtz1aagFgfTcfsHsyDPy6IPPT8+rAQRAyjVi0m/qAZGRRrQ1tKAsl3WwVBJVRrQrShZlvOW/ZA0oPgcx3r1U4GdsO3ANUuBw5RMIc4GeAMgxGBoAAgxGBoAQgwmcgxg601dIN8z7+5JX9uz6T0gt74RHbqPtW8KzhevPwt0F9YNgpywVU77oT1LQdfTj68dX6smDOvvsyzLqgygb17Iou3r0JLyfglrdMsNIGIdgHB0x1zRolxSjn7Zxyy7jAFIXFc9oywFLnlis5HWLizrAORIsLAYgCwMft02A5HfMsOxEN4ACDEYGgBCDCayC/CDt34e5A/1vTE4n9u8HXQ9m7Brb8dHMOXW5qhy2dwGvG7X2biwI6ZPydmBU3y8IfzekYWLg7PsvEsP4LW4mAXRanTUn8Ir4d283Ij3sLiWBrRj+LnDFSwFtsvqau6Ka7xcHOqI+54fkgbUF4f+9r16KbCcCCRLgdXn1k4DChfBn7wbUE4TCkv11UohktcG3gAIMRgaAEIMhgaAEIOJHANojaFf/OznVAnvz8/B5aC9T+Jy0I+uPwry98fVws8LV28F3QtljAno6bksuvyWV8Ca3HEtUznmF0FXdxj90Yk56IRC7METW3gw9GDVHVJnO45/wpEKxjussp6eCy8FjompwH5FPWPVRCBRCqxvDsp7uMw0bCJQVZmwdOOZ6juh4Q2AEIOhASDEYGgACDGYyDGA837xUZB7Nz0WnBv2rQVdbEUPyFc3fwvkM554f3D+5zU4QfjrQ1gafFbjy8G5eUceH0okk8sLld9/yEV/u24Q5eHl+F7pf+tUGvC98d3aexPo84+W8U/qQymwyPvXKAW2tDoA+XQlWQdgR68D0NuFwzYH/1Y/jc1A4eoab36VsYeaY71e3ceeqPAGQIjB0AAQYjA0AIQYTOQYwKpbciCX37A2ODsPPwO6XZ/aCLIchRW/LxucN5yO3/PerTgi7EiPSsIndvWDzmvAPt3F84eC885yK+jqDmN9gduCNfxhuA3oQycm1HvtJMYARsqYg7cqqlZBuvyxcrifa7vKYU0I37Xk4fMnYSQYvri6nl/Z/bgjNweLZxDylDYDhcH6glkBbwCEGAwNACEGE9kF8HbvA3no1u7gPO9GnAL81j94BOSvDK8Eed79apKP9+d46azfgu20D8eXBOclh18AXXwxTik6vX1PcH6xsAB0iSGcIFzXjG3HeU9zERyxCagBV/jEJ7TNOiINOF7Gz01rC00LvriaV2pcg6EUGCkLFwDbgcXzT2EikHQfpsOMlRHTfThm8AZAiMHQABBiMDQAhBhM5BjAwes2gPzo+juC8+lXfhR0/znnhyCvuO+DIC/f+lRwfnAiA7o5W7DFty+r9eLKNt15WZDPbvhFcP7R0HrQ2UeGQe5sxum9Y75K0tkJ/LPU1+MzxbU0oJVCn3+8JGIAZfW5BTmtt0YaUN8eXLU52Ju8Hbjs1SgF1keC1djaG14KLCcKh5cV45unqT+O8Gfxb+ENgBCDoQEgxGBoAAgxmMgxgGuuvQfku/MdwXnjO7AUeGsZi1477sXy2FhjY3D+Sv+5oEs+vwfkbNdy9b5sM+iGu7BmYHXycHC+OTcHdE2jh0Be2DAB8pBWdmsn0Y9vrcc25NiEih/4IgYwUcK6AN/V8/Nia2+pRilwRR/dhZTd6HUAVWPBrcnHgruyViFkc1BN15b5+lkPbwCEGAwNACEGE9kF+HB2F8gnf+n64Pzkn9wBujOe+ADIix7sAzm/sTc4b30ObVDvIE4Ubn5ZlfT6XfNAN7oQ39sVV65G/yC6Cw0T6FosyWBq74Cr3BJbXOs76nDT0UihST1THb62WBR/Um0ScEFM6nHK4Wkz/eYes6KnAT1LTjsKKwWu8QzH6Bo/pRQhec3gDYAQg6EBIMRgaAAIMZjIMYBLtr0N5O4vPh+cD16L03aaNjWBXNnzIsh9H1ZtvG2b8XucelzDY2tTgEbPWgK68S70K1O2loI7LCbziDLiJakBkPeV29R31mF6cW4aU4ijEyoF6mUw7VcpianArvrecbmxp0Y7sF4K7MgJyK5IKepTgUWrcCI2eSlwrCrNF54GjKqzrPCpwTVjC9p75abn0PzjdMpuZ3HJ7kzBGwAhBkMDQIjB0AAQYjCRYwCVW+eCnGxWPtxVL/4R6LJ3PwuyM7cT5I3nqtFeh7+N48T83m6QvWe2BeejvbhxKNk1CnJRa+lNDwjbJsZ8LUwMgfyrMTW2zM9gq/CCVA7kl4sq5uG2YzuzXxTDu7RtQAUf4wVOSfrm6OvaMBJM1AGIUuBYSCmwzPVDKbDcDlwjBuD6x2gqMJkV8AZAiMHQABBiMJFdgOTPngB56+fPDM4dm/DaaMcxxZY7pxvkL8y/LThfv/1i0A1eeQrIrU9pabTlmG7cMPcgyH0VtRy0bkAs4qzDa/38GLoPu/IqDehlMF3XmcBpQnZBPUelDq/bdmlym5qXacAapcCQBhS22hWlwHrHX9VUYHnN9yfvBqx6hhoTg+BzxWtDXYTXKeXm020BeAMgxGBoAAgxGBoAQgwmcgygcDlu8fzPt90RnP/Xpy4C3dDbVoN8+GxMdy1LqKWe3jhu7BnagH5xxw9UW++qZQdAd1bLTpCfK6lUZWYAv9NpxEWibTH0BfeNtQTneAO2+M6NYwzAL2lpwDq0oXZxcudWpgGtikzPCR9abwcWdbWuKAXW/fyK3Bok0oB6jKAqRSgmAkk/3g3RTQf522cloSXIx8Hz/w54AyDEYGgACDEYGgBCDCZyDKD5pr0gZ7UktV/B7bn+uwdBfu+CbSA/MKG1o7a3gW79Ka+APL5IjQG7rPNh0K1N45ivu4bVNqD0QBF0fqsYEWajP354RMUIOhrwz9Iaw5FgVkmVHFfSYvxWSAxAtgPbZTmtF/1xB/+sgOuK79V8aDl9uGozkGb3EyHjwiyrVjvw5M9nWeHtwDU5Tn3q4w3eAAgxGBoAQgyGBoAQg4kcA7hr+c9AXvLTjwbnBZdi3nnTSbeB3CycxXe8eFVwTq3Ogu6aed8C+f/2vD84n595CXQyl//4QHdwzgyi317uxDFlGQdz/fmcGgNWbkC72OaIeIJWB1Cuix4DyHv4nXYZnXy91dayLMsO6QXwPDkSbPI6gJiILejbf2qNBQ9jOnUA04sPTOO9BOANgBCDoQEgxGAiuwA3Dy0HedUtueC8928wvTUvhlfdhI1X0iO/VKm92HpQWeemsU33pl713p6EaKcVd8G9B1RKceVRTFtOnNxuhRHLqT9FqQE/t1G4MHra0xUugPAW4K475mJLsuVGTwPK3+pPIQ0oNwOFLgetkQacisPAiUGvAdN0h3gDIMRgaAAIMRgaAEIMJnIM4DtfuhDkzj1PBed/W/cg6P70wBtBflMzlgJ33af8/B034iOkbJT1MWAyliBJ9KnYgz+KsYR8hxipJSbwJnNKX8bOYSsjvtfX0ncVXCJkxXBqmWXH1HvzrkgDVmSJrhhjFrI5yHdlua/6PVXtwJZsB1Z/Y5ki9PzJ04uSqviATGNqepnirFXqO600IYkMbwCEGAwNACEGQwNAiMFEjgHM+Qqu8T3w4VOD8+oktuk+/IN1ID902lKQFz6tSnov/z30QTeLbTmrl+0PznsrWN7bLLb9NOxTZ6+ICflCBzqVEz4668mcOssYQEq0Dlta/EDGANJHULa0GMC4K/7cIgZQlpuBYCQY2uqwOoCK8OOToh1YH09esw6gaqS49nzM8x/38AZAiMHQABBiMJFdAHsFXuOv/dCPg/NfHMYr/+Lv7Af5wMgC/KyYsjvvb/016P5h4I0gXzZHLRp9ZAIXia5M9oPc2KfVzoq0U6EDr8GjHnbipXLq9RO4B7U6/ah9tluH3xMTpcB6GnBCLPS0xCQlWWYbNhHIcuVVXfvYqk5BOfVnCpuBptUt+KrfWuOD6XocK3gDIMRgaAAIMRgaAEIMJnIMYPsnMN91XbOayHvSP18OukW7HgF53i+xBLayYWVwPjmJr/35C78H8sfOvzc4f6bvMtCNtOKEoLo+lSb0RIow3lEAecDDn57Oadt1G6LbRRkDcEpyU7L6nvEKphPlNOWSH70U2PJEGlATKzW2A+vlvlXTgqzwUmAvRDcVaoYHZsrPD93uMzNfOZvhDYAQg6EBIMRgaAAIMZjIMYCHzvsiyJdse09wXvKdAdBVzl4LsvvIMyDv+9TG4CxLcpu34NivJRepMVq/2d0NulQMfWjnkFaHW58B3YL2HMgHKrgpKJlT2378+vC2Yx2vDusL4kXx3oT6E+crGAuxXIxLTKUOwJ5CHYAcCaaX+1a39EZvB5Yus2xnZqnw7Ic3AEIMhgaAEIOJ7ALsEW1vlVtVvayz9QnQvfz3Z4C88gVcyrH4jSqFeP9EK+jmbB4HWZ+Gm3gJn+HpBiwx7jy6KzjLpaOrsodA3l3CKcHx4Qn1PQ14VS/7eM3Xa1xt4QLESsKmamnACZEGTLn4W8tyaI4bcoUWj6Q7Hq4sBa6aCKReHeYezCQz5h7Q7ZgSvAEQYjA0AIQYDA0AIQYTOQZw1fevB3nZzx4NzvaG1aD7zAXfB/lf7v4DkG9Z8qXg/Nk9b8UH2oobffa7+eCc3YH+av9cjC3M0aYAeS2oW13/HMi7ih0gO6MqBtBYj9+TF6lKvcU3VVcGXayAaUA7ofz+QgX/3EmxGagkU3BTSAPGtLiEnAhUtRlI08uJP5KwiUG1SoGZBoyG/zqWIPMGQIjB0AAQYjA0AIQYTOQYwIrbd4E8dvnpwXn/eWhH3tc0CPJnL0C/+PSU8otffGYx6JYffQzkX08offMOnAo80o1+vqW1AJc7sBR4ZeoAfu5R3Hbsj6lYQ3tGTtFFWW/xzaRxBlisKMp9tVLgQhn/3E0iBlCWrbgh7cBi0C8g6wBkO7CrtwPLOgAZhwiJEUylHVhuPq4J23ZfE3gDIMRgaAAIMRgaAEIMJnIMwC9hLrzlJlXP/8523P77/TH0zc97w/MgP1tSbbDtW0Q+uwnf+6OBtcHZ2X0QdA19jfjehvrgPNaBvnh3fBjkfaNZkJvyqlegM4N+fU741JaW228SMQCnVA+yr8UASqIOwHdljT5+jx0aA5jCWPCQXoCECCZ4NUaCuSFJa+nlh7rqrBGYFfAGQIjB0AAQYjCRXYCdN64EedvyL03ySsta8uMPgnz/JZ8D+WN7VGlw21NHQeetwLTg5l3qEXuGngZdY59YwzNHtQDnO9C2dcbwpw4O4wbQxqIqQZ6XxhrcI14aZDupXICWVB5040V0YXztteWy3DCEl+aCj8/ohLQDy4U9Me3CLa/pVaXA2mtrtQOHbQY6lqW+U04TkmMCbwCEGAwNACEGQwNAiMFEjgHc8d6vgnzzkCqlPTWzE3TLNmFqacnl6G+/8GBPcF66/SnQDfwRbhpObZ88mZTcj6m98rxscC5gt6/V4KAfX8rh9GF9S8+8ZA6fyUW/3k6qFGNLEsuTJ0r42/2U8vsrFYwB+B760GURAwhPA6Ls6O3AbvhmIL3ct2pzcFUrsdwMFL0dOHSicI1y3lnZSnwCliDzBkCIwdAAEGIwNACEGEzkGMDGdA7kT972geD8jxsxH7/8QfTr75tAn3ThfdrorgJuxzmyAXPw8+7XRli3tuBDDRwBcexU5fgX5oTM07IsK56b/KfPFWXD/WXcImSlVAygNYmjvQ+W8HvdZjXK3CvVqgPAseGOG5KDdyd3SF0vPJcPY8FrjASbCq7/Gvnt4qe7+t+xlp9+Avrx04E3AEIMhgaAEIOJ7AKc9fg1IC/8yubg3LR7DX5oF27s+fSOJSA3/UZ1D9qdc0B3zskvgXz4HxcGZ3/hXNB5z+Jrx7qUPUt34NW86OP03mRO3AW1vFRHfARUWwvzQfY1F6A9gWlAq4wugKelAf2ysLfiylzyxUThsrrausJdkGlAKAWumggkpwLrqTxZJhw+EYhTgU8seAMgxGBoAAgxGBoAQgwmcgyg62YxrWbF0uCc/gWm/fZ95HSQ/fvxsxoKappQZeMq0F3TiSXHf7vnguA8fD5O8m14Gh3h8YVKXt6KbcZHXExVpnLonzopVRrc5mCLb79s8U1r7cBxjDXYJYw1eEnt71YJz0HJNKBdCUsDouxottyT7cCWLDmePA1YNRU4xI+vtVUIPle8tmZ84HVI1/kGxix4AyDEYGgACDEYGgBCDCZyDMB6Aif7bv/G2uC84qZW0J1yxYsgD96AdQH2KSuC8+EN2Ja7MYXOrTukyn2PLkd71ZjAyb/1XaPBeVW2H3T7XPye1LDwSetUyW5rDP34QxMYA/DqlK/eFhN1ACIG4GoxALsUbm8LnogBaKXAnvShQyZoeTVKgcPqAGqNBJPxhTBYBzD74Q2AEIOhASDEYCK7AEffdybID513a3C+9Io/A913F90O8jufeQPI+7U0YXEdptHCKPRi56DTmgV5TadaAHpypg9024rzQE7lxMLPerVMtNFGuzgwgcs+6tLqz5aN4fP7FeHCpLXrNnoHVVRNBAqZCuzIUmCtlNmrsRzU05eDVpUJh08EgueTS0MmfeXvem/010r3Z1YuFQldZjoLn/e/4Q2AEIOhASDEYGgACDGYyDGAc258HOQ9FZU2W/CuXaA74mFLrB3Hr0mfPxCcL5uPLb2PFsVEW20K0Pqle0E33tEO8sbsw8F5ZeoA6O4aXg9yMofLTv1GFQNI2fi8ufE6kBN16hkbHYxLWBUxESipxwCmVgpsVfQ0YHg7MHynW6sUWJ8KLP34Y9cO7HD6zqyHNwBCDIYGgBCDoQEgxGAixwBum7cF5J5/vz44P/LuvwPdu7ZeBXLiNJzm+9kV3wrOi+PYtvs3By8B2Vuk8veXdfwCdF/vehvIp9WpDUVzY9j+u3UEx4nFh9F316f3yhjAxBiWEbt1ym422pjc90UMoJKKXgdQrCoFVo6+nLgb1g7se+EtvWF1AHKz8OvFVOoE8I3H9DFOeHgDIMRgaAAIMRgaAEIMJnIM4Oq9WM+/4naV+3feja8d/zbW3Y9dgI7ZJRndP8+A7tfPrQB5Xq/KuZ9Xh1uI7+zCx18eV052ysZW4Z1DbSAvHsHtP6WFquU3JnoBrDH0zct16vc0OqIvtyxHgqlzVR2AcHTzHj6zFboZaFKV5dcYCVbRYwC1xoKHtAtXxRbkM06hj4C8PvAGQIjB0AAQYjCRXYCtd5wEcra0PTh/YOcVoOv44TaQO3+YBnlvRU3RaXaw9Ld1syjDXa6unIvi6C6MLcSrbksM9TrjA6jz8wdBLmbF4k6N2BjaSa0K2kqLa7wvru1uWBpQuBpVE4G01uJapcCOdjUXS4Sqy3mnkOoLnQo8hWu891otDiVTgjcAQgyGBoAQg6EBIMRgIscAGr/9GMg7//Ks4Fz3XXxtZ34zyH+9+Gcg/+2hC4PzuU3YDtyxBafs7rhx8kcsdpUm1UmSA/g5/jhu/yk2a9t1hROdGEOfGWMAGDvwXTESTKsiljEAW/TLFj3xWz1tO7Dw452QcWG+Gz7Wq+KpZ3ZEbEGOJas1MiyMsBhB1TgxGSMIfW/kRyA14A2AEIOhASDEYGgACDGYyDEA75x1IN/xXrXF9wtnnwe60UvXgrw2hePE/ushpX/p5DmgS720B+QNS1RufHcF/fb5C46APOiqEd0ZG3Pq6QF0HD2xwafYovQVMeA6IZb/uFoMICW+RybhXa26NzmKL5V1AEWx0cfS6wBkO3DIZiBffo7A02oGZHygMI3twCFhCTJL4Q2AEIOhASDEYCK7ALlP4AacjelccL5jCK/iA+/pBPnpIk7n6bpP3V/3uPNBt2wEJ/9e1q42/DyQ7wHdaR342pfLquR4fnwCdJnD4s7s4TW/lFX317yH7kFiDO+24/O08l45gkZc1d20kh38M1l2DO3vhCvKkfWJQHI5aNVmIO2zxE+t6gb0Jp8IVGv5p1weGgY7/mY/vAEQYjA0AIQYDA0AIQYTOQbw4Np/A/nUxz4YnDsvxM053zjtSyB/ei9O761/eEdw7mjCCUCxFpwgfE7dQ8H5T3dj2/EfzsPy5M2F7uBcSu8DXd1A+EjeclY51aNVpcDCr9c6i6umBwlwIpBQxtDnL3qirNjTNwMhYROBLJEGlE+o+/nVE3/C04C6Xr53KtQq533V1b6MO0wJ3gAIMRgaAEIMhgaAEIOJHAO4axxz+4tuVl7a9hvQId2QRF92+wNLQV589JHg3PLUEOgqqxbh92hjwJ7Z2QW6W7pxA/BnBtQG4HQrOtypQawL8MQoskSzqlXIibbc5Bj6upW66HbTS2l1AGIqsC1iAAVXlBVrdQAlP3o7sIwBxGSFsa++V9YIyHZfWScAz1BjKvBURoZNiRn73Gnqj0N4AyDEYGgACDGYyC7A//nae0DuevLR4Pzls54H3V8N4gThhffj9TvWuyw4ezuxnPfwtRsmfYbMdlzS2fX7eGV+9pAqK56TxtY7Z2gEPyyNn9XarEqdj7g4QTgxhgs//fTkE4QlXkpdjGNlYW+FC1ByxT+HqyYeTS0NWOOZfL0bcPLFH5YVvlh0KshSZjI74A2AEIOhASDEYGgACDGYyDGARXc+B/LQ+84MzhdlngLdDd/GCUFLn3gG5EP/Y01wbv8nXPg5vB5bh1+pqPhByw50fDMOLtPM9zUE563Nc0EXG8YYgJ3B8uWuxlxwPuw24nvHMaXoZJQ/KycIV5FU+qrloKIduCBiAI5f0L4H3xpWhWvXKAXW24HlVOBjNfXXssKzZmwVnh3wBkCIwdAAEGIwNACEGEzkGIDTmgX5nBvVpN8vHl0MumWbcPaVV8INPuMXqDG7c++eB7qLVr8I8gP55cG5cQf68WUfYwL1fSqvvncethUvzveDHGvB712c2R+cD5TxvbFxfP6ktuxYThC2RImxk1J6WQdgx/HPXxQxgLS2abgs8/OhpcAoyqoFP2Tsl6wDCKsTqG4VPna1sjMWI2DsAeANgBCDoQEgxGBoAAgxmMgxgK034fjue+b9ODj3fOs60C17AUd12RtWg/zJU34anP/15D8A3dUdPwL5r/Zerj5nL/rxB13sMWjYp/zV/m7M8/sVrOf3GutBXlo3EJz3Ftvw+fMFkDPaqO+Cj58rN/4mkkrvlMSfW8QASmIseFrfDiz89rBeANuVfrzYbqzl+qWPL+sA5PafqTCV7cDkd1NjSvu04Q2AEIOhASDEYCK7AD++/A6Qr977+8G559u4PdNeswrkvvObQb6y8WBwvnk9tvTKaULPvaKmAPXmNoPuxRJe1Rv6VBlx+jC6ABK3OQ1yd1K5AE+OdOOLJ9AFaNLeWpSlwKLFN6W7ABVx7Y3ja8shm4HKwlbbXsj1ukZ1sgubgeSCTzERSKYBfb2M+Nhd471aPcw6oTXG034Uo+ANgBCDoQEgxGBoAAgxmMgxgIzIO239nBr71fgkpv1e+duNILevOQQybN/dMAy6okirZXZoLb9iC8+j47gtOHEwF5zTAxgDsFM4AqyYxdjDgph6joP5JtA5BYxxtGhjvsaFLy4n/WZSqozYKYvUpIwBVMRmIG0SsNzaWxVP0JnKZiDZDlzDifZgM5CMH0w+TmzKO4SYJnxN4A2AEIOhASDEYGgACDGYyDGA83/+pyD3fkf5/bEV6Itfe9nPQV6TxtHfXzhyanB+7/InQferAo7jatmutdO2tYLusUH062OH1ZahzABuMnIaG0AuZNHfbo+puMRQHsuE2wu4vaglqW8Sxj+hbPFtSKoYgF0KrxmoqgPQ4gslaavDRoKFjQy30FcPy/NblmUlan3Yq4Tp+tkBbwCEGAwNACEGE9kFWHUrTvkpn7MuOO+9ANNbP2l9JfSzPnT/2cH53os/B7pP7H07yI07VHrO68ZJvy8fxMfvGVVTfeoGcIqPlcXUXjGLti/rqM8aGccy4TbRSdiaVCXHox5OJrYS+Ez1cfUcE2W8TvvitZWKsMdamXHZF5OGQiYCyQxaTFy4ZUoRdFOaCBSe3DtWHX9TKhMmU4I3AEIMhgaAEIOhASDEYCLHAPy9+0HO3a580rcuwM1Av8LuWWthDEtp592v3rvsMkzPbXl2Gcgr925T33kpthnH9glfViudTQzkQeW24vcUs/jWOlv58qUx9OvlNKH2hNoknPNwk7Cc8tOUVFOLCmX0Zf0E+vWuLAXW0oCubAcOiwEIl9kRE4Gm0g5cy8/XqVVGjJ9bY6tQyEfNymlCx2lekzcAQgyGBoAQg6EBIMRgIscADvzJepB/s+7zwdkRdmTVg1eD/KZlO0BueXB3cB50x0HXthk/y82pOoBcL+rq+/AZ9ZZf5yhuEZpYvADkcjP6kTGt1dgel7t0kJa4FgNwsWzYTmCbcVNcPf+AK+oAUvhaz5UxjcnrAMJiALXS5q43eS4/rEbAsjBGEKvhi9fy82cE8fhV25s5TgzgDYAQg6EBIMRgIrsAf/JBXNjxw7E5wXllEhd2LPgWptF++aaTQV52UHUS/mgM037tW3BCkKct2yz3Ymqv4yn8nlirWurpjYyCLt8hym5bMLWnEx8Lt4ttWlpTLhKVpcBNcZUGtEUpsNuA3YyeLAXWKMuuwxAXwJGLQcTdVl8OGrNkqa+YCizShDARaApTgV1/FqbuCG8AhJgMDQAhBkMDQIjBRI4BXN2MU33W3nlDcLZPQ7+967+eRtlbg1/apVJy/7JnKeiaduwBOdahtv+c2b0bdIN9mNrzOtXEIO/QAOgm5qAfnGrGeuWiryYCJUZFPkjUpWZjKg347MRC0MkW32YtBmBVMAbgiS1IVlU7sDYRqCoNKCf5aPIUNgNJqqYPT6EU2LXke7kcdLbDGwAhBkMDQIjB0AAQYjCRYwAXvfgOkBfd+VxwHrnk90DnNONk38wj20Eeeotq6z3yFPqNDeM78YtXqxjBW9vvAtXX+7EVd2yd2iScfhr97YkO9DnnN2OL8qinRnclsDrZsuNYspt1lF9/tIzP4CfxtQ0xFWuwy1h74MVl3erktahVdQBT2A4sS7X1d1bl+X1ZQzB5u/BMlvoyRvDawBsAIQZDA0CIwdAAEGIwkWMA9q0dIDttykdruuc50B26CvP+7f+M24MPXaB84fk/xfx2rL0N5CO9qt32rDSOJfvXIRxVPtq1ODinRe7e68Ax4V0NOfwezW9OjImNv8Kvb3RUzcBgEUeN+amYeG1IHUBCjPmqhMUAorcD10rde57ux4f3AkimMhY8rFdA+vjuFPoKXq+2Xf8EjEvwBkCIwdAAEGIwkV2AxM9xiefWO88Izis/mQNd+5X7QI49iC2/1532YHB+8G+wVbh00mKQc73qvtcVx+u2V8By3nGtKrczha227R04Iag7gws/+7XJPknpAtThpqCMdhXMlXArkpcU7cBaGtByZSkw3mXDXABZCmx5YpKPdoWWN/OYdIc8vR04PA04panANaYJGUutP8vr6FrwBkCIwdAAEGIwNACEGIzt+5zVRIip8AZAiMHQABBiMDQAhBgMDQAhBkMDQIjB0AAQYjA0AIQYDA0AIQZDA0CIwfw/7TvfjA96458AAAAASUVORK5CYII=", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "import math\n", + "from torch import nn\n", "import torch\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# get cifar10 data\n", + "# import cifar10 dataset\n", + "from torchvision.datasets import CIFAR10\n", + "\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", + "\n", + "# set a seed\n", + "torch.manual_seed(0)\n", + "\n", + "# define embedding class\n", + "\n", + "\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.return_patches = return_patches\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n", + " self.classify = extra_token\n", + " self.conv = nn.Conv2d(in_channels, out_channels,\n", + " kernel_size=patch_size, stride=patch_size)\n", + " self.patch_conv = nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", + " self.norm = nn.LayerNorm(out_channels)\n", + " self.proj = nn.Linear(out_channels, out_channels)\n", + "\n", + " def get_patches(self, x, patch_size=8):\n", + " # get the patches\n", + " patches = x.unfold(2, patch_size, patch_size).unfold(\n", + " 3, patch_size, patch_size).to(x.device)\n", + "\n", + " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", + "\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", + "\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + "\n", + " return pos_encoding\n", + "\n", + " def forward(self, x):\n", + "\n", + " # embedding = self.patch_conv(x)\n", + " embedding = self.get_patches(x, self.patch_size)\n", + "\n", + " # flatten the embedding\n", + " embedding = embedding.reshape(x.shape[0], -1, self.out_channels)\n", + "\n", + " if self.classify:\n", + " class_embedding = self.class_embedding.expand(x.shape[0], -1, -1)\n", + " embedding = torch.cat([class_embedding, embedding], dim=1)\n", + "\n", + " \n", + "\n", + " # normalize the embedding\n", + " embedding = self.norm(embedding)\n", + "\n", + " # project the embedding\n", + " embedding = self.proj(embedding)\n", + "\n", + " # add the positional encoding account for batch size\n", + " pos_encoding = self.get_pos_encoding(\n", + " self.out_channels, embedding.shape[1]).to(x.device)\n", "\n", + " embedding = embedding + pos_encoding\n", "\n", "\n", - "def sinusoidal_encoding_table(n_position, d_hid, padding_idx=None):\n", - " '''Generate sinusoidal position encoding table'''\n", - " encoding_table = torch.zeros(n_position, d_hid)\n", - " position = torch.arange(0, n_position).unsqueeze(1)\n", - " div_term = torch.exp(torch.arange(0, d_hid, 2) * -(math.log(10000.0) / d_hid))\n", - " encoding_table[:, 0::2] = torch.sin(position * div_term)\n", - " encoding_table[:, 1::2] = torch.cos(position * div_term)\n", - " if padding_idx is not None:\n", - " encoding_table[padding_idx] = 0.\n", - " return encoding_table\n", + " if self.return_patches:\n", + " patches = self.get_patches(x, self.patch_size)\n", + " patches = patches.reshape(\n", + " x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n", + "\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n", + "\n", + "\n", + "# use the embedding class\n", + "patch_size = 16\n", + "embedding = Embedding(patch_size=patch_size, in_channels=3, out_channels=8, return_patches=True, extra_token=True)\n", + "\n", + "\n", + "# get the training data\n", + "train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())\n", + "# get data loader\n", + "train_loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True)\n", + "\n", + "# take a batch of images\n", + "x, y = next(iter(train_loader))\n", + "\n", + "# run the embedding on the images\n", + "embedding, patches = embedding(x)\n", + "\n", + "print('embedding shape: ', embedding.shape)\n", + "print('patches shape: ', patches.shape)\n", + "# plot the patches\n", + "no = int(32 / patch_size)\n", + "fig, ax = plt.subplots(no, no)\n", + "for i in range(no):\n", + " for j in range(no):\n", + " ax[i, j].imshow(patches[0, i * no + j, :].permute(1, 2, 0))\n", + " ax[i, j].axis('off')\n", + "\n", + "# plot the embeddings\n", + "# plot the embeddings\n", + "no = int(32 / patch_size)\n", + "fig, ax = plt.subplots(no, no)\n", + "for i in range(no):\n", + " for j in range(no):\n", + " ax[i, j].imshow(embedding.squeeze()[i * no + j, :].detach().numpy().reshape(1, -1))\n", + " ax[i, j].axis('off')\n", + "\n", "\n", - "seq_len = 100\n", - "embedding_dim = 64\n", "\n", - "pos_encoding = sinusoidal_encoding_table(seq_len, embedding_dim)\n", "\n", - "print(pos_encoding)\n", "\n", - "# plot the position encoding\n", - "fig, ax = plt.subplots(1, 1)\n", - "ax.imshow(pos_encoding.detach().numpy())\n", - "ax.axis('off')\n", "\n" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 1.0000, 1.0000, 1.0000],\n", - " [ 0.8415, 0.6816, 0.5332, ..., 1.0000, 1.0000, 1.0000],\n", - " [ 0.9093, 0.9975, 0.9021, ..., 1.0000, 1.0000, 1.0000],\n", - " ...,\n", - " [ 0.3796, -0.4645, -0.9086, ..., 0.9997, 0.9999, 0.9999],\n", - " [-0.5734, -0.9435, -0.9914, ..., 0.9997, 0.9998, 0.9999],\n", - " [-0.9992, -0.9163, -0.7687, ..., 0.9997, 0.9998, 0.9999]]])\n" + "Files already downloaded and verified\n", + "Files already downloaded and verified\n", + "epoch: 0 batch: 0 loss: 2.338132619857788\n", + "epoch: 0 batch: 100 loss: 2.208583354949951\n", + "epoch: 0 batch: 200 loss: 1.97751784324646\n", + "epoch: 0 batch: 300 loss: 2.0264248847961426\n", + "epoch: 0 batch: 400 loss: 2.054466724395752\n", + "epoch: 0 batch: 500 loss: 1.817784309387207\n", + "epoch: 0 batch: 600 loss: 1.836090087890625\n", + "epoch: 0 batch: 700 loss: 1.5311310291290283\n" ] }, { - "data": { - "text/plain": [ - "(-0.5, 63.5, 99.5, -0.5)" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAGFCAYAAAASDy0NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA9KElEQVR4nO2deZxcVZn371JLd/W+Jp10IPtCBBKiSJCAAgKK6MyoqMAwjqKIo7ihvK/juIyODjgC6ujoDChuML6KG4qCE0AFBLKAQEhIIGtn6yVdne6uru3e+/7hZ+45v+dat6vJ1p3z+/51nj619b31eeo8ux0EQWARQozEOdYfgBBy7KACIMRgqAAIMRgqAEIMhgqAEIOhAiDEYKgACDEYKgBCDCZR7QPnfPlLR/JzkGPEC2/5xrH+COQI4UzfMv5jjsLnIIRMUqgACDEYKgBCDMauthhozY4TUR6bE66fGp0Fe1uH20DuG60DOZdPh+tiAd0QfknoJF32bdzzUbT1fflfRWR8LTvuKhzH5VKvOesJkM9q3Azy4tQ+kNvdUrhucFzYq7HxXjra74tj4fV2bf72HGnoAyCExEIFQIjBVG0C7OzpArk7Ua/2yiOwtyY/A+R1o7NB3jzSGa73jjbC3lCuFuRCPhmuPWEeBNJcKGuy+K/sccwHePxEzAPJFDMXOtaivH8VXpiTl+wE+dXtG8P1qbU7YG+mi9+DJkddx4ydhL2kjeYDTYTDD00AQkgsVACEGAwVACEGU3Uq8Nn3fwDkr5/1/XB9UQYf25rpBXl6Igtye3J+uN6QRH/B9kQryAdc9eK5fAr2yiIM5evqrIw2pTDrLVvYnGC7+4HYEo/VX0fa/PKhk9wn0HTHYyA37FwG8ubz54Lcs7wpXL96Vgvsrax/HuQFyb5w3eoU8X0c/OpJn0DcdaN/4PDBK0mIwVABEGIwVACEGEzVPoDFNw6D/F7rinCt+wMsK+oTODWVF6/2vHU4yFnCJ6CtfUvYlGUUI6a7pxnvUi0Kn4DuUIj4FuJ8ApPQHxCsPAVk56EnQZ5lLQN5l6XSvH9rLbLEgxGVKgL+gD+DN6RBXnPtuiXEvfQCzFWgT+DFwytHiMFQARBiMFWbAP4LmPY5+46Tw/UnO14PezOWfg/kU1I1IC9Ijobr4ZrdsJfz8Vg/5qkU0pInj4Ii1Kel+wZyLxC6Tp7qte1I2nAkCqg9Wb5PnEkwCUOE+z+G4bnuD3SD7K/ZCHJnqzIZ9tRjGPDxWqwYbXLHwnXGLsBeMplFOfAqfkZHHPFl2jB58fAEQIjBUAEQYjBUAIQYTNU+gH3vXgHytK+pFNL8yS+Hve91rQT5E52PgNzmqJLfuckDsNebagB5sEbFFEfKadgrSp+Ap/SZL4xx36/sL/jzH7SliOXZdpzNWTlEOC6TwCfw8Eu/DfLKSz8M8syb94Lc8CfVIahp+kzY29HWDvIzGZXmPS05BHuNDoaGa+wxkF2tXtux0D8QTRtWj2VIcGLwahFiMFQAhBgMFQAhBlO1D+Dt77kH5N/+eHG4nrk6C3s/XXEqyBc3/QnkV9Yqm22aizroBOET2J9W5acHSthdeLSEOQOFkvp3ymW0EwNX+AA8YYDr2aXS5p+ACyCyrT13Qq3FjhL/M4Z2++lvfArk3fdjum/5iU3hunlLB+yNzML7saVd7c+t64e9jsRBkBtEuXCNq25IUvRv88UNcWTaN6kangAIMRgqAEIMhgqAEIOp2gfw/uatIN/+pteGaz0nwLIsK7MO8wJ+tRh9AsvTKi+g3sbY/owElh1PS6j48b5kE+wNJIVPIKls0GK5co6AZVmW7Yq2XzFThSJ5AWDYW4KplRfw8e9eCfLjV98E8spzRV7An9SHrNmKJb71CzAvYKBb5XRsbUFfw5w0PrfTxfve4Ki8AOkDcMTvlqPfH5YKTwheHUIMhgqAEIOp2gS4+LlLQJ77ZjV1ZOw2PMZ3rsc0z/85G0NJl7c8Gq6XpVEHtQqVNF0zAXYmc7DXmMT3OZhQZcc5FyfROI4IJTnijRztGCmnBsWFBSdwTB+3e9AxYPZXN4DcdxV26qk/dz/Izk9V2x9v1x7Ya+iZBvJQnzLJerqaYW9vA8ozk4MgN/mqfLhGmGvSJLAYBnzR8ARAiMFQARBiMFQAhBhM1T6A/L/hBJ87v/nlcP3m098He6lndoE8vGUeyOsWqtZRS1PYEqzBwXTSNle1D2sXIcLGZDPINYlSuE66+DolF/9VzxFhQF0eJxVYDwvGlwpbVlz7sMmAXV8P8hUbMSz4uYU/A/lTy64K13XPb4O92t2jKPc1h+u+g/g++1oxpJtNYyvp6a5KFS4GJdhLC+eJDz4B2fqNYcE4eDUIMRgqAEIMhgqAEIOp2geQvmcNyCXNtt11HuYBzH4Q0zybNs8H+ZGzlHxJ/Quw1+liem+rM6LWiRHYa0hgHkAmoUpKUwlsRV5w0Ra0Hdn2S60DsWfJ9mFHi6MwVWjTR3CcT9ud+L+e9wVsx3XNaeo3Y949aLc7+wZAru1Vdn5fFu/Hvjy2fhuoQx/BaKDyOJqED8AP8GJ42sVxJp+bZVLDEwAhBkMFQIjBVG0CFC5+GchXblGVX0vPwWGfYxk8GrZswakw6/er6TN7puFH6BRZnQ3ama7ZwVTgBleaAOqomHbx6DoqUoFjU4MjYUBx/rZjqgGnWGrwXX/1ZZA//pkLQN72z2h2tS3vDdf2DEz9lanBmT6VVuxm8T735kT35yY0/YZ91Tk6LzoIl6y41GCmBU8EngAIMRgqAEIMhgqAEIOp2geQuQ5Tdvf9UKXz3v5/boa96xe+A+S06ByT3d0Vrjcv6YS9k1NYFpqxVThITpPRp89almXVuroPAMtaExEfgAgDanIkYzdGHneK0CQo+Y2j3cUQW5BHf83n910I8lWzHw7X35/3OthLidTgdJ+6X6mssPlztSD3l0QY0Feh5ZL4nfICvLeedg/8SPcgOVGaU4R0eAUIMRgqAEIMhgqAEIOp2gdw98JfgnzJXco2POWTmObZvxxLPdvuwHTfzC6VfrphrBv2XpvBFlS1tirrbRI+gAbhA6hLKPs1KfIAEu44PgDddI+L+08UeOrk6xi8avUHQJ55MX4lHvgDvvG/vOW+cP3vC7DkuvO3aG+7B1QOQVr4AA6OYPr4gSLmjgz76juVD/B1Sxb6AOLwxYVyYsc8mQdPAIQYDBUAIQZTtQnw5UGs6POzqlvv4wUMJQ2swON267cxtFTfo45lGw52wV62RXSl1ar66kQor0GYBBltwGSNCG9JkyBqAmhdfmS3oEiObkwq8DEY7nEoLLkRh3Tu+Bwe67tvxaN552UqZXdoIV7TrnpM57Wy6rXTWUwb9kexa3NWmgCelgoc4GNLQUHI+iBRceTniT8WngAIMRgqAEIMhgqAEIOp2gfwvW9cBHLTBcrG/vR2HPy47BQcJCrLgxt6lA33wgF8bn832nvd2ifMiHCcPkDSsiwr4+o+APQlJMcpB7Z1WaaIRuz8I9Pp91iUB/tbd4L83dP+API//fHVIG8rqdDetAX9+GKdbfjaWnlwTRb9Be4wfvWyeUwNHtJ8ADkfQ4ZFURauh/pkKjDLg+PhCYAQg6ECIMRgqAAIMZiqfQDTvrkW5OdvXxqukw/Mhr2br7wN5K/MvBjk1G6VQ5Dtb4G93R6mEZ8cKHuvxkZ7rs4R+QVai7C0I3wAMjVY+gD00L4t8wAsxK6w/kschc6+h8K+d68A+ZTUIyB7Q5gncOeQevxFMzbC3kNdLwfZ0cqDU1mRlzGCvp7hPNr5Q+WYPIC48mBxP+LKgz1ODeIJgBCToQIgxGCoAAgxmKp9AM6C2SDfesZ3wvXnv44TZc9+N07xvWEexodrH1NtxJN92BJse7ED5HKtyilISh+AjXalXgsgW4K5trD3RL4/1AbExf3/4gMUkRZhQUzdwKGUBx8m3v6ee0D+TN8ykN2FONn5jueVrf6t5bfD3m9mnQ1yo/46Q1i3kRzB3JBcHmsQDpZVDcioyAMoBcIHAGvZMhzRcwZYGswTACFGQwVAiMFUbQJsuh67tr6yVh2pb1i3CfbSNr7s4EI8iKXuU6Glml48hm3PY2pwvnFzuK638SioH/kty7JqbCXXir2UCAO6ManBcYNDLeuYnNSjHKay4/c3Y9r20lvfB3LmXHxhf51645ecjnsj3fh70pRQ3wNneBT2kiMY/h0ZE2HBktYRyJflwGgK6ncybnCoZbE8WMITACEGQwVAiMFQARBiMFX7AH7/qq+AfNOBU8J14KE9/WQRQ3AHReuoab6Sa/vQRts2iiHDYe2xTQnUVzUi5FanhwFFKnBKyNIHEDfwN9olWJdlDa98rraehKnAFz93Cchz/xunOG3/LPpdOm9TKboZB0N3o92i7ZpWBh6MYglvckSkW+fRrh8pq/eNlAOLEt+S9lLesRixPIXhCYAQg6ECIMRgqAAIMZiqfQD7PbT3/uNXajLQvGUjsHdbP+qVrvloV9pJ9VqZPrTNe4abQR7ylb2HM4Qsq0YE6PXy4KgPQOQB2DIVWMsDiNj8VmVZPvYItQg7UqZt/oszQE5vXAPyJ0/eAfL3nj4zXPd7GNvPdOP3wG5SycD+wAHYS42KPIwcfmeGi7oPAL97Mi/Acxncf7HwBECIwVABEGIwVZsAb/nJtSDP/2+VzttzIXbx2bxhKcjvOO1hkB9pVsNB030YHto9hNNlBnzsFquTFGfzGq06MONit6CENAEc2Q1GrwasPDXIsg4hmjcJpwalf41Hfns53rs31q8D+du7Vaff9YVm2Dtl2h6QB9vU1Cd/917YSw6LDk1jGNobLapj/7CHw2dLkTCg+h2THYCisroJDjsG8wRAiMlQARBiMFQAhBhM1T6ARTfhBJmyZgs2f2Euvug900F+zTlPgfzwtGXh2h3A7kGlIXxuX1mFkrwgC3uy7Fj3AdTYGAZMO+N1BdY6xTiypFQQ1xV4ikWkChe/DOSeV+I1HfKxk4+dUCG4X2aXwd4rWp4H+a529b1I+uL6j2I3J3cM033zJfU5xjwM++VFWNDTLroXjBPCjcHELsHH/39ICKkIFQAhBkMFQIjBVO0DCHI4idc9aWG4/tyCH8Pe59dhl+AlojVrvrshXNc+th/2ktmZIO8rqxyDsjWAn8GWeQDK7k87YhJNJBW4cpdg2QIsvktwvJGpdwm2I+N/I4+WT4597cNB/Ud7QH5fB077uWXgDJCduSeE6/t3ot1+xXKcKvT9DnXjZXdeZwRbtiXw62UVCuoZI57wD8hJQVoegPTXxHUJ9mW7sKnmwDkM8ARAiMFQARBiMFWbANuuXQJy4Krjk94h2LIs68ZnsYJMdo4Z7lZvm16NFWWpLB7D9habw3UpwANerY2vm9bCgCk7viOQ7BKshwFl6m/EJIDNmL2jyYvsPPTzBb8CWYa+5vzi3SB3LVf7hc34zy94GZpduU712EYH026dHIYXXWEClIvq8TIMKLsC62FA/CZGuwRPmvs1SeAJgBCDoQIgxGCoAAgxmKp9ADdccTvI63JzwvULJewE4w0Ogiw7x4x2K0OsrSQm+GTRZtubV2HAnPAB1Duov1JaaK9GDA6VHYISduXy4GhHoMo+ATkMdKoZmTcNLgD5ZbXbQJ6xGv+f3pcpuRkHQlktLg78zHdo6dUpEQgcQx9AYkxMIMqrr2auHN8RSPcJeOPcDj1VmFOCeAIgxGioAAgxGCoAQgymah/AubXY1fXijLL7r9r1GthzGkC01hZaQc7PQrtfp0b4APbn1YvlZUxXkNZsuhqZCmzHlwO7MeXA8ZOCxjEk40qHJ0FLsDv+40KQv7ESA/KL/7Ad5NS7VXuuxN14XyXFDuV3sevQPxDksWVbUvgA7KL6bYr4ACKpwOprLPMAZCowQXgCIMRgqAAIMRgqAEIMpmofwMo17wD59mW3h+uHfv8S2Ju3CNt8/XIQbf4ZM5U/QZ8SZFmWlc6ird6fU23Ch/34Ns4pLUAv8wAiPgBRDuzE1gLE5JPHlgpbh62kN1JJfJhM22n/uRbkpm2nglzeh+XaV52wK1zfsQ1LhWW+R12HavluZ9AHICcFJfKiRVtB/TaNldHmL4g8gKKeBzDF8jCONTwBEGIwVACEGEzVJkDXjXjsuuLD7wzX3avxuN13GsYBd+2cD/KrZz8XrjfX42OTQ2gu7BtVk4GGRTdYiT4pKCnKgWVYMG5SUHzY7/jCWTgH5Jr7ngDZPXEWyK/NqK4/39s/D/a2inDd3DbVwanciNOjgr378H3GRFhWKwfOl/FrKoeF6pOCPGErya4/+qQgX9xpEycF8QRAiMFQARBiMFQAhBhM1T4A+1Gc7tN+5+nhuvaP2El2+18vBjm5pRHkM0/eEq63NL8SP9AQpqIWRpWPIOtjKMmyRDqp1s5qvDBgMiYM6IzTEkwPC0aifBOZFDQJUoM3fQynMS/+GKb39p/TDbJe8huIUu41YzghakmjsvOfaJoGe7YnwrI5lJ2CsvMLJfyaRsuB9VRgvKhy2g8nBSHH339ECKkaKgBCDIYKgBCDqdoHkL0C0z5bf6TixV4BbfFXnLIZ5Of+C1uKL0urycLfaq2HPXd/Ft94pEV9Bg99AF6A/gJH02dRm1+WA1dODXZiJgf/+Q/Wi0K2D4tMCjoGPPyqr4B88d98FOTsGXhvNxZVeq/TgDkcDw+2gHxR29Ph+rFmnEKcFqXdbh7zNlztbYsiD6DgozyRVGC9PFhOKzIRngAIMRgqAEIMpmoT4OUfwKqx5+9Wx3FXGxhpWZZ11TQcFvovW2aD3O2qw1ehA4/1tVv3gJwYUTpqwENzoWz1gqwPC02J3jByWKgbU/Enu8XGTwY6jF2B44aFHiFzYZcYvDn9UpzqdGXnsyD/IPvycG3PwNDeU/vxtT44475wXWjCNFt8pGXZY+L+aCZAoYTPjU4KSmhr/E2Tw0LjMHFYKE8AhBgMFQAhBkMFQIjBVO0D+MqMNSCf/lfXhOtSPdpKK2swdJTahra6Pi0414EfoWYMQ3uJYfXag2VMW5XTgtO2sg1lObCcFiwnAyViJgNNuWnBE0gxvuwn7wf5kUu/BHLGRvt7xcNqWvD0BTWwl+vBN571UvU9KDTHO1bsAqYVO5pYFj6AoggDlmLCgLJLMEF4AiDEYKgACDEYKgBCDKZqH8DVPStBTr1VdYsdG0Nb0BF6xduPPgDddh/rEDZbUZTxao1m+0uYB1AS5Zq1ml0py33HTQ3WU4HH7QqsdxDGranWPGzRTTtBdi7FfXkvk+vVPRhciI+t60G51VXRfukDsBMYy7fz6ANwC+pK+iX8DDIPoKiXAwfxv2m+loIspwaZOC2YJwBCDIYKgBCDqdoEeOoWHBhxxxf+LVyvKcyEvccLeJYKyhiC2+upUF++UxyafXFUH1H7B4qYNlwSRzi9Y0tS7MmwoBwO6thxYUD8iLEnxcOZGnwUCHIYdr1q69+A/OoOTAXuXKdCe9vehOG56b/D3xM9LFtsFtc0hcf4oChNAE0QJkDRw/fVw4BF0dnXm1q346jDEwAhBkMFQIjBUAEQYjBV+wAa73wM5I4b1VPfWDcIe1ftOgdkBxvHWM8W28J1qQPDfpKU7gMoYCpwPqgcdEsKWy81bocg5QOIKxWOMBGbf7yOwccghrjtWuzWVPMj3L/pjBkgL35GxfpmX4/h38QPsKOwTrkZfTB2jSgIFuHfhBYGtIUPIB8pB1Z2vwwDylRgGfozHZ4ACDEYKgBCDIYKgBCDqdoHEJyJeQDv71EpoTfM/A3sPfg0TgZa0p0F+Y+jalpwc/sI7NkJ/EjJEWXFZfO1sDfqV9Zfcs6rzAOQPoD4yUAT8AlI9MdOgi7AkhuuuB3kb65aBXJyBKf9lPepFPDXdg3D3gN7TwF5xM+H61QzlojbNeg/8LNDILtF3QeA163gyS7Byicgy4HH6xJsOjwBEGIwVACEGAwVACEGU7UPoO9jeZD3/vjkcP2ly9EWbFmPLzs6HyfGPNo/J1zPa+3Hx6ZSICdHlO2ezWPsuBBIS197nkjgT1ovvhw4gr4/kWnAhxHpTpiIW0Ln3NoDIH+9fwDkjgfxmvua7X5h/aOw90Av1oTs99S9a29CX09Qi68b9MvJQNr9KIoWYDG1AHqLcMuyLC+QLcK0MuNIlkDl79PxCk8AhBgMFQAhBlO1CfDwiu+C/OZr3xiu/9/c02Fv4To87u1dhbnAo3s7wvUVJz8Oe2sy0/EDjqgy0bE8poCOBpXHO7riLC5TgV17AuXAFd9l6rNyzTtAbj8PS65T9+JEKOclKsS7KClKb4fRFNxebgrXsxqysDdUh1OFLE+EZTUTQERwY8uBDyXsJ00CxwCTgCcAQgyGCoAQg6ECIMRgqvYB/DrXDnJ5pyoLnbkaS0bdTThhdvgdmBqc7FGhpKWnYyvZtfXzQHZyygdQzqMvYdjH1GDLUummSVt0sx0nFTgZEwaMpgbra9yLZPvaFdZ/iWNQHjzjBvwKbLkG7eDF6ztAHliuQrp6yy/LsixLlGdvKqjvxYkZDDc+UTcLZFv4ACAMKFKBi2W0zfNaKrAMA/oyNVjvJD1Bd4H+XNc+Pn47j4//ghDyoqACIMRgqAAIMZiqfQD/9J0rQJ55tmon3fjQNtjzDh4Eee78fSAf+Gl3uF6YxKlBfiPGoZ2DuXAdjDXD3kEfS0q9QH0mOdHGjbQJr1wOnBA5ArHlwFM9SeCxp0H8xvefAfmzqzBPoH+F+t/3lkUpdxrTezeMqtTgU+p2wd6aevQfJIX/wCmp++Ngx/BIKnBZKwv3REuwuLwA2R6sclbJ8QtPAIQYDBUAIQZTtQkw+6t4NNx4owrtLXwPHuPd5iaQL5yOz/1Zj0r3nZHA8Fy5CUN7yf0qfOTmUF8Ne/hY39IeK6sBI0d+MTzUUfuH1AFoAgTyfSIlfvBo+eTD8hmyV5wB8gWZJ0C++nx83yVL1FH+sQKmbbstzSBvzCoT7fUt+LqlejzGy+O3k1ffC0ekApdFGLAAYUDc8wPZIah6fDk8dMrbe1F4AiDEYKgACDEYKgBCDKZqH4Dd1AjyDef+MFx/ZzZOAip1YQeg8+vvBfm+npXhusXBUF6xWYSHxlQnIjePNtiwCAPq5ZwyDJi0S0IWPgHN3oumAoMoUoEtsSls9cgDJhcrP7gG5K9lMUX30pVYrr2y/vlwfffAMtjzp7eBvLtf3csZC7Drb7EO7w8Gfy3L1sOAYnhUsYzPLWmhP8+qPgxIeAIgxGioAAgxGCoAQgymah/Axuu6Qb60Xtl0//ZK7AZbqkO7S7aOcvepacJJW5R2NqOcKao8UDeHrzvioQ+gFCi7MW3HpwKnRHmwE1MOfLTyAo4YMWXGt3Rhy6/5d74H5Ecu/RLIGe1+Xb8DpwZ1zcRU4HKfugcdIphfqrcQ6SvRpgVLH4Dn4b0t+uprLPMApA9A9gE2HZ4ACDEYKgBCDKZqE+Anr/8yyD8YVt1eBs7DoSF+CY9hsnOMP6BSdqFDi2VZhWZxZNOOggl8G2uojKnAenWXTNtMyi7A4jCohwXH7wikVwNW7hZkWUelqc8hcbUWkrUsy5p/J1b4tb1Vdl3S2FIH4kGMIFo1fepiNDhiuEcDXijbxe+MXVImgz4o1LIsKxBhQH1YqKwG9IPKv3G+qECU1YEy/Hs8whMAIQZDBUCIwVABEGIwVfsAWl2MxXxitZoM9OFzMNX3qREMGcrOMX5eGfMjQQH2is3ijX1lm7tjuHWwLMOAWtdWZ7wwYOWOQI41xcN+E+Cpm08FuXEtDvy8axTTuhen9ofr5s34WtmFKNfvUtet1sahryV0H1h2QnwVNR+ADAMGZTTOy5qdP14Y0NNvpQE2/njwBECIwVABEGIwVACEGEzVPoBzfvtBkOffoVJ03/n6LbD3VO1WkB/J4+QgPVi+38N4fLGpcrJmYgxt8eGS9AFUttWlpou0BIM8ADk5OGYyUMV3/N8HxBidcVOELOuoJBE03ok2v7toPsif3XASyBecsClcN23JwV7feWh/1z6h8j/kJJ1yvbjPScwVCWJSgS2RB1D09FRg/ErL8mCC8OoQYjBUAIQYDBUAIQZTtQ9gyY1ZkL3NL4RrX+TVr0ijLfjNXow1O1op6NZSK75ui+gBrZEUPoCD0gdQ8ZmWlRT2dVweQMIRk4HEa8XmBUT2Jnew2T9rGcg7z8XmXDWr8fE/O0Pdy8Xb9sDeiTPwfiR68d7qePV4/e2UaAxe1moBSsIH44lpwb76vkXKf2VtgP4ZJn2lxpGHJwBCDIYKgBCDqdoE8LfjcMfEiar289ahQdi7phnDgg9vnwPygnbVTejZPHYTqm0SNb+OOt4lxvBoPlrC9NJ8zLQcV8iyHNiN6QgUx4TShOVjD9N0n0Nh4GOYX/36WU+C/PTlmN/rpdrDdbm3H/ZOb8dj/dMDqkNQzscJn249mnq2CAP6oyrE6JTFdROpwCXNBBg3FXiSm2RHG54ACDEYKgBCDIYKgBCDqdoHsPc9K0DWG/J+ZV077J105m6Qnc1Y+1nqUnbYhhFME57WNAyynVQfUfoARorSByAtfYWcFizDgJAKHGkNNYHy4PHSe48QujthIm6Jh077HsiyS/PrNuI/0NmpxXB9vIavqMf64GcOTAvXIwEGaevqhK8njffSyio/kSPCgI4sB/a1cmBfpALHtASTyHB21HN0/METACEGQwVAiMFQARBiMFX7AK56969A1tt+PfW1U2Dv1rlng9y8GW240W7VanpTthP25jVhbLkvpWxDdwxtzrEixo5LMfZepBzYkm3C46YDCx8ACFM7nfSXuQ6Q9ZZflmVZbkMDys/sUEJzE+wtTfWC7A+rVnADIn23tQ5LiQPhAwi0MnFHtAW3ZR6AVzkPwJfTgjVnic9UYJ4ACDEZKgBCDKZqE+C9zdtALjSpdN83/w6rvtasnAfywufxuLf3FSqUNNqPx8jzu54DuT89PVy7eUwfLRTw4+eDyv+Oa8WHATEVWFQDTvFjfhyf/s7lILsvzYLcvhKrA1P3qmGizksWw153AoeD+mMq1LfHQ1OioxY7RQ/XoiloaR2enZK4H3jrIAzoH8bUXxkWdI7DsCBPAIQYDBUAIQZDBUCIwVTtA7ho0xtA/sb8O8N1eQeWCretw8lA7rZ9II++Za4SetFunJPuA3ltrSoltscwnbRcRPt01NdfCx/riFTguK7A7jhhQF2W04ClvwAqfieaJqzvHyE3xOyvPgPy0EXYBbjnXPyQ8x9Sad0jC9B/I6dA66nCu0ptsDW95iC+b1p0jva01OyyTAXGh5Y8PRVYlANHOgIdnosqp1rLrsdThan5qQkhhwUqAEIMhgqAEIOp2gdQ/uJ0kK/4wN+F69bpGNNtX4f2ndc3AHKqW5WJOusxPjwriY8NMqru2C5gW6mggM/NB8oG9cTUYZkH4MoYr54HcCjlwFMMu6kR5MZfPgXyvA/hV6R34exwPbgQ7e2CKPnV27ltK2DK8cx0FuRnM+g/cLUpT05RdBCW5cATSQXWvgeenCRlYLcwngAIMRgqAEIMpmoTIPWbNSDnOleG6+wq7AjUeA+GlgLROWbp9L3hentfPezNcLEjkJ9RoT23fwj27CLqrxyEAdEscYSuS8pU4LhqwJhwUcQcOIRjZCBNjaPQNXjjdRiyXfxPaL59asY9IL91+UfD9egCNMl6ymh2ObXKfNuWw65Q57ZsArlci0d3XZKpwDIM6OkdgWRX4Al0BJKDQpIVHnc8wRMAIQZDBUCIwVABEGIwVfsA8pecDnLbzzaE6403Yllo/Y9GQXZF55gzW7aG696+ubDX5qId5tWrTjHOHrQ57QLayHoYUHZ7kV2B3Uior/py4GMSBoy4Aw7PlKGfveEWkN+19kMgz0v+HuSBFeraLJmHw0GfLmKo2GlQ/p1dIzg4dHoH+nO82pjfonJ8ObCnpQKXRSpwJCyoXSfxMuOif6ec4yRmyBMAIQZDBUCIwVABEGIwVfsAmq7bCXJptQrGXn7GH2FvfTu2dwpmYBroyzMPhOvf9K2CvXpRUlqqVx8xWcRUU1f4APRy4OiUF0T6APS8gKQt04Slv0CtZTlwBHju5LMbm0VQPXgrdmV+cAx/I0475YVw/ao2bN/2x5H5+FqtyvfTexDva6uLeRrlmsrXxi6htS7zAIpax2HZGdqLSQUmPAEQYjRUAIQYDBUAIQZTtQ/gFwt+A/JZr706XH+w7SbYe9vS94Gcb8OpLwuSY+oD9KMtmHHwsaV6paOCIuYBSB9AzlfPlaWecuqtE6kFqJwHIInLA4j1CUxC8/Oc334Q5P85/xaQ/37T34L88XmqNuDExCDsXdNzGcjJNpX/n8vi9W9zsG6gVBvnA0Cj3xY+AN+vnAfgHYV6iqkMTwCEGAwVACEGU7UJcMPAApCzl6mje5ODaZ69p9WC7GHjX6vN0fYHsfxUUtRNgDKe/cQp0hrx1OeITnXBf1WGAXEyUHxHIJ0p2R1IOxUvuTELW10Xogl24IEukM9eqsq1kzbe2J17sPPvjE51zd0sHsUbHJQ9aQLotpQXHwYMYsKAcR2B4g09M+AJgBCDoQIgxGCoAAgxmKp9AD/8+vkgf/9jt4Tru0awDHTktDGQ/aJo96RNUQmGsQWYnLhSqlc2WyDCQS5GBa2cp4UBx5n6Ekn31SzCaMfguMlAsiWYLB1W68noLQh29ID8oT3ngNy9Gu9P+v2aXS+m4SR70H+Q0zLAU8IHkBFh2TK6jSzbVft2WfoAxPSlmHLgyGSgmBZhvggdy++QcxxGFHkCIMRgqAAIMRgqAEIMpmofQOet60Be+gn11MvXvxb2LjnpaZCfH8Zy4CFf+Qj8AgbzxwI07Et613A5kVXkAYxqCQcl+VgHdZ3MA0hpqcHjtQWfkrH/Cuy5+jSQd/0E97ufXAvyOm1Kz3Q3B3v1OCTaGutURnMqi3tykrD0AViaD8ASPgDZEkzPAygLGz/SJhwmA4n3PA5t/PHgCYAQg6ECIMRgqjYB7EXYvfdTvflwnVqNXX/f/tGHQX60Fp+7raTpHRF6GfYx1FeuDyo+1i2gPObpXYHjkZrPiU0Frvxq43YEggdPvjPnu6++G+R7LkaTwHfxSt3Wd3a4flUTTvdp6MF7N7hUPbdhu5zMJI7mtcLMAhNApICLVGBLNwF8mfrL37g4eHUIMRgqAEIMhgqAEIOp2gew+XqM02z/+Vnhes4DfbB38j9iiCfjbAb5wZxWWuygLdjn40cq1VcOuTkiFXi0rFJRi0F8qM4V5rceFpxIKrAkYtVPpCuw3NblIxR5fE/TDpB/sQ1LeoNXLAP5vg3qfg0vxDLw2h7s7pToUPejJitqwgURH0BCfQ8iZeAiFdjy9TCgnARU2SfAcmCeAAgxGioAQgyGCoAQg6naB/DQOV8F+Z2f//tw7W16HvZkOu+cBNqKXxiaF66dOtiy9pQxpyCok0FfRULkAeQ0H8BE8wAm0hJMN82nelrwRZveALL9CvQB7DovA3LTerV+PDUb9hbtx2nBM9vVfU9kW2M/h18r8nuT2lfTEz4Z8ZWwY/MA0LESVw48Xgn58QhPAIQYDBUAIQZTtQmwQ5RreRu3hGu3GY/t94/hce/iDIaHntw/M1zPaMSuwNuL7SAn6zVzQuTdukWRClxW4cfSOKc5V8jQEWicwSBxTMgkkI89BkMsyl/Ebk47LsPfhIVLMEzof0Ldn0IL2m/eYBbkJc3q/9k2hN+fUiAq/KQJoIUBrbE8PlaW8WlP9fzqqwEJTwCEGA0VACEGQwVAiMFU7QO44i4c+Dl/hbLdc50YKvrP3Sgvm/sjkLO7G8P19BZ8n20F7B7UUKfsPygRtSzLzaOtni+rf6cYE+6xLMtypT9BDwNGusFOoCvwJEC6EuI+Yuo3a0D+9C3YJXhxei/ub7wgXDefuAjfV3R3Wlo3EK53DE+DvZwIFadrS/iZk1pp98go7MWGASfQBZj+AJ4ACDEaKgBCDIYKgBCDqdoHsOimbSA/95HZFR/r/+lEkP/QhXJml3rbUgf6C3bkMIegPaM6z+olopZlWW4BfQDDJe11J6jboBw4MjWo+nLgCHaF9SQhf8npIF/Z+KR4BJZ2e4OD4bppC+Z3BKK0e3FapQb/amQx7OV8jPtnakSLZ0gFjp8MpPsAZB6AnAwk5TjkhOlo9sjUhycAQgyGCoAQg6ECIMRgqvYBBEWM237gNb8O108Oz4K9p//jZJB/vmIZyPU9yoYb68CJsruGm0Ge19QfrvuSaI86RbTRiloeQGkcWy/SFtyqPg9AJ74F2OSn5TrM9b9rpBHkBalekN1GbX875ghY9VgbMDsxFK6DHE4RygpbvVH4AALNBxCIcmA7xgcgy4FlLYCv3XlPJEz4LAcmhJgEFQAhBlO1CbD1WgzjvL/l/nC9rX4D7L33iW6Q1207AeQ5Peq4N3ASdgvKDdWDfNa0reG6P4UhQqeAOaGlkjru5YP4f821KqcCR8KAQp6M6b8vlp8tuBfkOXe/C+TlS7aD7C9S5cPB2mdgz503G+RprnaMF2nCB3y87y1pNBHGUlpnIjHo1RHlwPrtiR7rUT6U9F8fzMTjIyTIEwAhBkMFQIjBUAEQYjBV+wBuufw2kFePKRvovFq0263N20FMb8aJs6nd+8N1/hy0BYtigkxXKhuun0lh+yqniCmiZa0rcCniA5B2vfABaCEgafO7MeXAkY7BshRX2z8kz4E0XQ+TG+KGgQUgz7sTr+mG8+aDXK/dyva1+FqlrmZ8rKPurZzu0+dhuLElNQZyLq2+X4EvrnEkDKjWZU+0AItJBZb+gSM2fmkSwxMAIQZDBUCIwVABEGIwVfsAVtZkQV5+/z+E6/96xXdgz89jG+fmLaKssu9AuMx34iSaRBY/0nQtndRKY9qwXUS70i+q1tP5ANOGLUuUmwp0u989UrZg3PTfvyQfBX749fNB7vjdoyDPspaD/MJlysae1or93A7Oip8ArLOvhK3kW1PY9mtXSn0P7EgegMjLgHJgvIgyFTgOT06UnsD98ALpN5oav61T41MSQo4IVACEGEzVJsCZj10F8gl3qqPVJztxwGTTNOwU07hlGGRvSHUUrunA8I+zqwHkjoR6bCBMAKuEJkBQUvqsKCfCyCOaTAXW0zwn0BFoqqcFd966DmS3eybIzuObQF71BfWV6ZuFKd8j3fh7Ugi0Tr8iPrq31AxyexK/M74WBnTF0dwuSRNArWVHID+SGqxVA07GFk1HGZ4ACDEYKgBCDIYKgBCDqdoH0H0D6gp7/RPheuBk7Cxb8xIMuaX/tB1kT+sIe0LrIOz1ZTGtuM1RZaJBjegIlEW70Sormy4aBkQcofugK/B4k4FiwoQRn8AhmJmB7ms4QpOD7UVzQd51IZZcd92MXX+umvZIuP6/C98De6OzMI34gKe+B04aQ4T7CpgCvqJBTCFOqfsjA3l2ubIPQNr8kenAE+gK7In7HP+NmprwBECIwVABEGIwVACEGEzVPgBrDbZ/cjrVFN+Zq4dgb/d5mOY54/4DINtJFc9f3LwP9g4OYWy51VWxZL9W+AB6xUTZotJneR8fO17HVz32H8kDsOWEGO09p3gewObra0FeOe9ZkAdW4wTglenHw/XgAvz9qOvOgrzLU3a/XYvvs38MczramtGf42k+AGl72zIVWBN9kQosfQJ67F/umQhPAIQYDBUAIQZTtQkweOUZIJfq1PFp2n8+DnuF65bGvpbT2hyuT848BXtPZrH6rEGrqvJqxHDQshgaqVkE0Y5AiBvTEci14jsCxR37ZUcg3JTPO4QjaOSp2mtP4Gj70DlfBbnBwet25nkfrvjc/EKs+jxz2h6QNxW6wrVdh0Ng+8YwLNjsYjWgV1P5f7BFV2BH7wgUSQWOl+E9K+5EkSalM0XTinkCIMRgqAAIMRgqAEIMpmofwKprHwP5ueFp4dr7Fr7MmxY/AfKTre0gBx0q3XRxGu3GVBaHkKZt9drlWkzrTIlOs06pciqwb1UO5VmW7Aocn/qLXYFjX3bSs6OM4bkZCSzPrjm3D+Q/FtQ9OG3uTthb2fwCyBvGVEg3aEAfQHYUw4ANDvoTvFTMhY1JBfY8mQosugKzHBjgCYAQg6ECIMRgqAAIMZiqfQBf6loPcm+nitte/rL3wd7ftnwN5PUn4MTZwjRlD85K4FTYxBDagroPwKtFfSWnzeh5AAWRCiw7viZt9Cc4miEp8wAOhfi8gMP2Ni+aK+7Ce9d9Kpb//vOiX4B86/6zw/XrOjCH46T0bpDv7VP5IF4T+hrGRvCr12BjWnc5HZcH4AlZrYNx8gDk9GDT4QmAEIOhAiDEYKgACDGYqn0A79x5Fsg3zLw3XPech/bdkhTGfA8uxFbfxQZlh7U7YtrPQcwJ1yeslGqF/VZCu1HPA8j5ooX4OEAtwDjlwHFtwadaefCim7aB3PuaOSBf9Dls73bN06o8+DMX3gN7zQ7+nmwdUFOfOprF1KARfGyDg9cYbp90pEQmA6l1IGz8sqwFiMkDOHyen6kDTwCEGAwVACEGU7UJsPGWl4D8jmtU9955r8Rj5JCP6aRZ0TlGj9BlhAkQ5DAsqFMWJkAgjoJ6GDDaEUhO+0kIufLRPZIaPMWO+XEERUy97vgpTgLa+Wns1NO6Tl23Ey5GU08OxBztU/uNzRh2dUfwXtaIY76nhwHF69rluDAgbB1S2M8/hGGhUwWeAAgxGCoAQgyGCoAQg6naB9Dw34+CvKvtzHD97etuhr1fj84AubAIfQJ+Uc56UQSj6APQp/pGfQCVW4IVfPzX5JQXid72y5EtwWICRNIfEJ0MpIcMcWsyeBK2XrsY5Nmfx2nB/7r/fJA71otpTDGk+tQ9KGCjaCsZ8QGICT5a1NCWNde+8OdoLcICD3/Tyn7lyUBx7cH+Eno6+VQvA/9feAIgxGCoAAgxGCoAQgymah+AvwrbdXf9+PlwvewfMc3zIy+sAnnFHGwdtXtEGYQ5H+PQflGUhWrNmj3MOI4EfeN8ABOJ6co24BOZDjzVuOXy20D+lyffDvK9D+GFWvScmhy0vYz+miZhGNf0KbnQgu+bFK6EtC3KtyEVWPxOyTwA7WsQyMlA4kZ7Mb95nrytx4mdHwdPAIQYDBUAIQZTtQmQvR6r9Nr+Sg38fLKAFWM7nsQw4N+97scgr6+dHa77hQlg+Xi8y/nqXF+uER9KHOudkpLHPEwxlmFAmbYaNxkojqlW/SdZWZMFue9teKzv/i7aXd7Bg+H6wdx82FsqOgJletV1HDgVz9ON2EA4MlnHq9HCp674nYrpCCRvXXRSkHofdgXmCYAQo6ECIMRgqAAIMZiqfQC/W/Z9kC8+/x/C9Sd34uSfDmwgbK1683aQa7R43fZyvRXHsBbq8zLx9jaGATEFdDyrXteEcR2A/tK+jkz3jbUyD+e04BfJykexY/N3X/YtkD999QX4hBYVz7u7dxps5TswlFfbp25ISZQDJ0VHIOmTgY5ALj438Ct3BLJkGDCoHAaMCwmaAq8AIQZDBUCIwVABEGIwVfsAfjGK9t6Oy5ThlXpwLuzNfQInyp6QwNZRJUtNBP5Dbh6+kYP2XtbXpwNX7wPIe2iPFmUqsMDVTEVX5gzIdmIwHXgCeQATtfntCmvLOmy1xCfcgC+84meifHZwEN/2zFPD9cat+PXJJMRk535VBp5swvuRGhFdggV+Ws8DEOXjsgxcz+EVPoByIHxBMSXA7ApMCDEKKgBCDKZqE+Czt78N5G+887/C9RdvuwL2/K1Y/SfpdtVx8OlcN+w5NZjCe8BT5kNQg0c/iaulAhc9UQ0Y+0zxOuOEAY8ngrXPgPy5fuz+7C5EE23PirpwndmMr/VUE6aAnzgwFK5bmzCPOzlSuSuUZVmWn9buQcQEiAsDiteJhAFtbS/+92+8LlLHAzwBEGIwVACEGAwVACEGU7UP4IR/fxrkC96vYm43r3kO9nwxbeaFMnYFXphUduTG7HTYczP42F5PDRZ1MuXYz6iXA+eFDyDS7UU+F9bxHgMYDir2plp58OCVZ4B8x924X/8qlIdOU6XfM3+F13hPF6Z1B0O7wnV3Qx72hkdFmbEc6ZPSfQCyIxB+D3SXjS3DgJFyYP7m6fBqEGIwVACEGAwVACEGU7UPwGltBvmrgyeGazlhNjGjC+QHcwtAntOobMOdB7Bd7Nw6tOH2lNR+qkb4AETasJ4HUBA+gNI4abdxUelDyQuYkE9APvYQJttWy6prHwN545V4r7Z+CvMyLpitukHvvGU27NXNxXvpa5OeT8xgG+CNo20g692fLcuynLSS7YTI6Sjg9w1TgWHLCmLyAGRLMBNbhPEEQIjBUAEQYjBUAIQYTNU+gI3XYZ73c7+eGa7nnzoMe7kOLP/9ZS8aZhfVqSTyfB/Gg/0GnAy0t9gcrjM12H5cTo11inotgJwKO44PQOvlFbH5Y3LCx50OPAmQ/7r+Eb/Uhf3bLtyA//vHT+kBeXF6b7j+9E7M76/fhSOAAy1eP7d2APY25WaBnA/Qv5NMabLwAcjpwHotgO3JcmDmAcTBq0GIwVABEGIwVZsAd19yC8gfeZPqJrv7XDz6+dj8xdr+Apb8PjtThYBqekXKrigb7S2oVOBGYQLIMlGnrM62JWEClCag6yIdgSZgEsQSyRt+cS9zOHnnzrNAdk7Fo/lbGtaAnNACpl42C3v1PeL+aMxOYZcoawxTgwsiFTgNJkB8RyDb11KzRSqwJ1KBPc0E8CZoDvgQY4wvZ54q8ARAiMFQARBiMFQAhBhM1T6AjI12l95Kqv6z2DbqYA7t+Mxa9BH8cYWaKlsjTMNiMzoQ9uYaw3VLGifXjkkfQEnZaKWy7AYbb3C/2HLgcdOCJ2FYUGfjzdgCrO9teJ306cyWZVlp/f8R03ySe7Mg+2nV+XemOwR7QR59AKM+XqdMWqX7BsIHEMip0OXKXYHlfff1lmCW9A/Ix07ue3c44AmAEIOhAiDEYKgACDGYqn0A5973IZCXLFITYz6z4OewtzaHk4LuueOVID/aPydcZ/rQ3s6LKbKjOdU+bGY72pF5kSLqFDUfgMgDKE4kD+AQ4v6TILQ/IRp++CjIf/1PmNb9lQMvBfnldS+Ea7etFV+sF9N9nQbVIqzdRV9CkMecgeEA72V9Siv5dcUUIeEv0F02wlUVsesnGvs/3uHVIMRgqAAIMZiqTYAlXzwA8rbL1LDQCzJ4vJuXXAfyH7YsB3nz3o5wPacPu7sMnIQhxNyokltnjMLe3iR2oLFL6vxXLqNuKwXxqZtuzOE92hFIyZOx+m8i+Kvw3ny09dsgz/nlu0B+eomqCvVnY0dnOWXInTc7XDc7+FULRGffYR87D9Vpg0YLWhfpPz9Z3A+tI5C8HXFhwEPpAOTLgbFTNDWYJwBCDIYKgBCDoQIgxGCq9gEEO3eD/Po3qHDQthJ2fJ2XxAkx9s69ILu7TgrXyb5B2Cs0ow+gOKJsw/Ykvo+V6ADR1lKBvTKmFHvj6DpH7wgkwn6OXf1s4WiHIH2Ne4fkPZDm64t8sez16Ff5PWboWl33o227vqTSvqctwA/RuAY/hNeqvge1Ntr40geQ9bGLVGNKTYjqTbbjc0UY0NZ9AOJWxZUDS/zDdVGnEDwBEGIwVACEGAwVACEGU7UPYM81p4H8i86vhuv39lwIe1+YcR/IXhZTeOt6lK3lDB6EvVKTiO2PKhu0JYH2qp1EO9/WWkX5Htqc0TwAkTOqMZ7N706kHDiO8VqEHYW84oeW/wDkpb/DuP+i320HuZRRadzZhfgBm9KYspvvUP4c147/rcl6GOtvTKhU4d6EeG4guwJXngzkxZQHH8m0YH3a8Xj/+7Fk8n4yQsgRhwqAEIOp2gS45l1Y8behqMI4Dzx0Muz9/LUvgCyHeDb0qOf6B7H6rNwiBkT0qY/Y5oowYBI/vp4K7EdSgeW/iiaAngrsinNkJCx4HIWHfjrSCfLMO9B0Ku/dB3L7elUBOHg+PtZtRfMt11H118s6UMbQcWNChQH9JN7LiKWkmQCOGAwSHQ6qXkt2BJJ4ovPQlCv1rAKeAAgxGCoAQgyGCoAQg6naSHtn006Qlz/2jnDdvRrt6W8tPRPkpg603TM9Kpznj2Kn33QT5qK6W9VkoGYXw4CB8AFYZe1zCB9AMZDDQieS3lv5sVNhOGgcn739bSDPunctyG73TJD9LTvC9RmzMQw7MA0HyI51KqO5EGDJOORIW5bVL3wATboPIIX3zg1kKrB2f8ZJBY7rDn0o5cFTFZ4ACDEYKgBCDIYKgBCDqdoHcMGzfwNy053KZss8sgn2nl91Esj1J2KaZ2K36h5b9tF/0NmE/oKDo5oPwBmDvSCFNqgzovkTRDw4mgeAOJoujEwHjpQHT6BL8CT3CZzw70+DbDc1gDxwDk4LbvqBKgt/ffs22Pta90KQxzrU/z7sY+s3O4H3brCE96czo1LE/QTeS5nUjdOBcU9e/aOVCjxV4BUgxGCoAAgxGCoAQgymah+A/UVsv9X4kLIdvTG0zTvWoyE2tBB9AC1PbdFeGO277vosyFtGVOvpBkdMl0kLazCr/Al2WfoAqm/bPF45sL4v6wJibf5DCDMHMt9gnGnH1eK0Yf7+vgsw7n/gDLzmrfe2heszax6CvX/tFhOZO5Tdf0BcUlv4b/oLWFfw0vrKtQCSuJZgvmwJpv3mTTTu72n3OhnzuKkETwCEGAwVACEGU7UJkLwPU0TtxsZw7S6cB3tNT/aBvPVvp+F+XqX7OjXYBXh2BgdM7hhRx66MOAb7Kfz4rlfZBJCpwL4M9R2mrsDHDPh3hRkSYy5svK4L5AVLdoF8aedzIP/2JavCdXcC03dHMWJotXeoUN4+0fHHrsX7ni1iN6FGV31H/NQ4R3Vfm9QUMQEqdwSKvMxhMqumEjwBEGIwVACEGAwVACEGYweB7HtECDEFngAIMRgqAEIMhgqAEIOhAiDEYKgACDEYKgBCDIYKgBCDoQIgxGCoAAgxmP8PeMDMSeU4oWYAAAAASUVORK5CYII=", - "text/plain": [ - "<Figure size 640x480 with 1 Axes>" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 217\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39maccuracy: \u001b[39m\u001b[39m'\u001b[39m, accuracy \u001b[39m/\u001b[39m \u001b[39mlen\u001b[39m(test_loader))\n\u001b[1;32m 216\u001b[0m \u001b[39m# train the model\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m train(model\u001b[39m.\u001b[39;49mcuda(), train_loader, criterion, optimizer, epochs\u001b[39m=\u001b[39;49m\u001b[39m10\u001b[39;49m)\n\u001b[1;32m 218\u001b[0m \u001b[39m# test the model\u001b[39;00m\n\u001b[1;32m 219\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mtesting\u001b[39m\u001b[39m'\u001b[39m)\n", + "Cell \u001b[0;32mIn[1], line 176\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, train_loader, criterion, optimizer, epochs)\u001b[0m\n\u001b[1;32m 174\u001b[0m loss\u001b[39m.\u001b[39mbackward()\n\u001b[1;32m 175\u001b[0m \u001b[39m# update the weights\u001b[39;00m\n\u001b[0;32m--> 176\u001b[0m optimizer\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 177\u001b[0m \u001b[39m# print the loss\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[39mif\u001b[39;00m i \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py:140\u001b[0m, in \u001b[0;36mOptimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m profile_name \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mOptimizer.step#\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m.step\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(obj\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m)\n\u001b[1;32m 139\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mautograd\u001b[39m.\u001b[39mprofiler\u001b[39m.\u001b[39mrecord_function(profile_name):\n\u001b[0;32m--> 140\u001b[0m out \u001b[39m=\u001b[39m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 141\u001b[0m obj\u001b[39m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 142\u001b[0m \u001b[39mreturn\u001b[39;00m out\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/optimizer.py:23\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 22\u001b[0m torch\u001b[39m.\u001b[39mset_grad_enabled(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdefaults[\u001b[39m'\u001b[39m\u001b[39mdifferentiable\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m---> 23\u001b[0m ret \u001b[39m=\u001b[39m func(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 24\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m 25\u001b[0m torch\u001b[39m.\u001b[39mset_grad_enabled(prev_grad)\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/optim/adam.py:226\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure, grad_scaler)\u001b[0m\n\u001b[1;32m 223\u001b[0m state[\u001b[39m'\u001b[39m\u001b[39mmax_exp_avg_sq\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros_like(p, memory_format\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mpreserve_format)\n\u001b[1;32m 225\u001b[0m exp_avgs\u001b[39m.\u001b[39mappend(state[\u001b[39m'\u001b[39m\u001b[39mexp_avg\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[0;32m--> 226\u001b[0m exp_avg_sqs\u001b[39m.\u001b[39;49mappend(state[\u001b[39m'\u001b[39m\u001b[39mexp_avg_sq\u001b[39m\u001b[39m'\u001b[39m])\n\u001b[1;32m 228\u001b[0m \u001b[39mif\u001b[39;00m group[\u001b[39m'\u001b[39m\u001b[39mamsgrad\u001b[39m\u001b[39m'\u001b[39m]:\n\u001b[1;32m 229\u001b[0m max_exp_avg_sqs\u001b[39m.\u001b[39mappend(state[\u001b[39m'\u001b[39m\u001b[39mmax_exp_avg_sq\u001b[39m\u001b[39m'\u001b[39m])\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ - "def get_pos_encoding(max_len, d_emb):\n", - " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", - " i = torch.arange(0, d_emb, 2).float()\n", + "# create a transformer class model\n", + "from torch import nn\n", + "import torch\n", + "from torch.nn import functional as F\n", + "import math\n", "\n", - " div = torch.exp(-i * math.log(10000) / d_emb)\n", + "from torchvision.datasets import CIFAR10\n", + "# import torchvision transforms\n", + "from torchvision import transforms\n", "\n", - " sin = torch.sin(pos * div)\n", - " cos = torch.cos(pos * div)\n", + "class Embedding(nn.Module):\n", + " def __init__(self, patch_size, in_channels, out_channels, return_patches=False, extra_token=False):\n", + " super(Embedding, self).__init__()\n", + " self.patch_size = patch_size\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.return_patches = return_patches\n", + " self.class_embedding = nn.Parameter(torch.randn(1, out_channels))\n", + " self.classify = extra_token\n", + " self.patch_conv = nn.Conv2d(\n", + " in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n", + " self.norm = nn.LayerNorm(out_channels)\n", + " self.proj = nn.Linear(out_channels, out_channels)\n", + "\n", + " def get_patches(self, x, patch_size=8):\n", + " # get the patches\n", + " patches = x.unfold(2, patch_size, patch_size).unfold(\n", + " 3, patch_size, patch_size).to(x.device)\n", + "\n", + " return patches\n", + "\n", + " def get_pos_encoding(self, d_emb, max_len):\n", + " pos = torch.arange(0, max_len).float().unsqueeze(1)\n", + " i = torch.arange(0, d_emb, 2).float()\n", "\n", - " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", + " div = torch.exp(-i * math.log(10000) / d_emb)\n", "\n", - " return pos_encoding\n", + " sin = torch.sin(pos * div)\n", + " cos = torch.cos(pos * div)\n", "\n", - "seq_len = 100\n", - "embedding_dim = 64\n", + " pos_encoding = torch.cat((sin, cos), dim=1).view(1, max_len, d_emb)\n", "\n", - "pos_encoding = get_pos_encoding(seq_len, embedding_dim)\n", + " return pos_encoding\n", + "\n", + " def forward(self, x):\n", + "\n", + " embedding = self.patch_conv(x)\n", + "\n", + " # flatten the embedding\n", + " embedding = embedding.reshape(x.shape[0], -1, self.out_channels)\n", "\n", - "print(pos_encoding)\n", + " if self.classify:\n", + " class_embedding = self.class_embedding.repeat(x.shape[0], 1, 1)\n", + " embedding = torch.cat([class_embedding, embedding], dim=1)\n", "\n", - "# plot the position encoding\n", - "fig, ax = plt.subplots(1, 1)\n", - "ax.imshow(pos_encoding.squeeze().detach().numpy())\n", - "ax.axis('off')\n" + " # normalize the embedding\n", + " embedding = self.norm(embedding)\n", + "\n", + " # project the embedding\n", + " embedding = self.proj(embedding)\n", + "\n", + " # add the positional encoding account for batch size\n", + " pos_encoding = self.get_pos_encoding(\n", + " self.out_channels, embedding.shape[1]).to(x.device)\n", + "\n", + " embedding = embedding + pos_encoding\n", + "\n", + " # apply the dropout\n", + " embedding = F.dropout(embedding, p=0.1, training=self.training)\n", + "\n", + "\n", + " if self.return_patches:\n", + " patches = self.get_patches(x, self.patch_size)\n", + " patches = patches.reshape(\n", + " x.shape[0], -1, self.in_channels, self.patch_size, self.patch_size)\n", + "\n", + " return embedding, patches\n", + " else:\n", + " return embedding\n", + " \n", + "\n", + "\n", + "class MyTransformerLayer(nn.Module):\n", + " def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):\n", + " super(MyTransformerLayer, self).__init__()\n", + " self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)\n", + " self.linear1 = nn.Linear(d_model, d_model)\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.linear2 = nn.Linear(d_model, d_model)\n", + " self.norm1 = nn.LayerNorm(d_model)\n", + " self.norm2 = nn.LayerNorm(d_model)\n", + " self.dropout1 = nn.Dropout(dropout)\n", + " self.dropout2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, src, src_mask=None, src_key_padding_mask=None):\n", + " src2 = self.self_attn(src, src, src, attn_mask=src_mask,\n", + " key_padding_mask=src_key_padding_mask)[0]\n", + " src = src + self.dropout1(src2)\n", + " src = self.norm1(src)\n", + " src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))\n", + " src = src + self.dropout2(src2)\n", + " src = self.norm2(src)\n", + " return src\n", + " \n", + "# define the transformer class\n", + "class MyTransformer(nn.Module):\n", + " def __init__(self) -> None:\n", + " super(MyTransformer, self).__init__()\n", + "\n", + " # define the embedding\n", + " self.embedding = Embedding(patch_size=16, in_channels=3, out_channels=64, return_patches=False, extra_token=True)\n", + "\n", + " # self.encoder = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)\n", + " self.encoder = MyTransformerLayer(d_model=64, nhead=8, dropout=0.1, batch_first=True)\n", + "\n", + " # create the linear layer for all the outputs from the transformer\n", + " self.fc = nn.Linear(64, 10)\n", + "\n", + "\n", + " def forward(self, x):\n", + " # get the embedding\n", + " embedding = self.embedding(x)\n", + " # apply the transformer\n", + " out = self.encoder(embedding)\n", + " # out = self.fc_out_trans(out)\n", + "\n", + " # get the classification token\n", + " classification_token = out[:, 0, :]\n", + "\n", + " # out = out.mean(dim=1)\n", + " # pass the output through the linear layer\n", + " # out = self.fc(out)\n", + " out = self.fc(classification_token)\n", + "\n", + "\n", + " return out \n", + "\n", + "\n", + "# get the data \n", + "train_data = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())\n", + "test_data = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())\n", + "\n", + "# get data loader\n", + "train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)\n", + "\n", + "# create the model\n", + "model = MyTransformer()\n", + "\n", + "# define the loss function\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "# define the optimizer\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "# defing the training loop for gpu\n", + "def train(model, train_loader, criterion, optimizer, epochs=10):\n", + " # set the model to train mode\n", + " model.train()\n", + " # loop through the epochs\n", + " for epoch in range(epochs):\n", + " # loop through the batches\n", + " for i, (x, y) in enumerate(train_loader):\n", + " # move the data to the gpu\n", + " x = x.cuda()\n", + " y = y.cuda()\n", + " # zero the gradients\n", + " optimizer.zero_grad()\n", + " # get the output\n", + " out = model(x)\n", + " # calculate the loss\n", + " loss = criterion(out, y)\n", + " # calculate the gradients\n", + " loss.backward()\n", + " # update the weights\n", + " optimizer.step()\n", + " # print the loss\n", + " if i % 100 == 0:\n", + " print('epoch: ', epoch, 'batch: ', i, 'loss: ', loss.item())\n", + "\n", + " # get parameters\n", + " for name, param in model.named_parameters():\n", + " if param.requires_grad:\n", + " # pass\n", + " \n", + " # get the names of the layers which do not change\n", + " if param.grad.abs().mean() == 0:\n", + " print(\"layer does not change\")\n", + " print(name)\n", + " \n", + "\n", + "# defing the testing loop for gpu\n", + "def test(model, test_loader, criterion):\n", + " # set the model to eval mode\n", + " model.eval()\n", + " # initialize the loss\n", + " loss = 0\n", + " # initialize the accuracy\n", + " accuracy = 0\n", + " # loop through the batches\n", + " for i, (x, y) in enumerate(test_loader):\n", + " # move the data to the gpu\n", + " x = x.cuda()\n", + " y = y.cuda()\n", + " # get the output\n", + " out = model(x)\n", + " # calculate the loss\n", + " loss += criterion(out, y).item()\n", + " # calculate the accuracy\n", + " accuracy += (out.argmax(1) == y).float().mean().item()\n", + " # print the loss\n", + " print('loss: ', loss / len(test_loader))\n", + " # print the accuracy\n", + " print('accuracy: ', accuracy / len(test_loader))\n", + "\n", + "# train the model\n", + "train(model.cuda(), train_loader, criterion, optimizer, epochs=10)\n", + "# test the model\n", + "print('testing')\n", + "test(model.cuda(), test_loader, criterion)\n", + "\n" ] } ], diff --git a/pth_lighting_example/mnist.ipynb b/pth_lighting_example/mnist.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3b55ac56167001fbbbd7f31eb86a2f0c31d101fd --- /dev/null +++ b/pth_lighting_example/mnist.ipynb @@ -0,0 +1,496 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import pandas as pd\n", + "import seaborn as sn\n", + "import torch\n", + "from IPython.core.display import display\n", + "from pytorch_lightning import LightningModule, Trainer\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchmetrics import Accuracy\n", + "from torchvision import transforms\n", + "from torchvision.datasets import MNIST\n", + "\n", + "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", + "BATCH_SIZE = 256 if torch.cuda.is_available() else 64\n", + "\n", + "class LitMNIST(LightningModule):\n", + " def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", + " super().__init__()\n", + "\n", + " # Set our init args as class attributes\n", + " self.data_dir = data_dir\n", + " self.hidden_size = hidden_size\n", + " self.learning_rate = learning_rate\n", + "\n", + " # Hardcode some dataset specific attributes\n", + " self.num_classes = 10\n", + " self.dims = (1, 28, 28)\n", + " channels, width, height = self.dims\n", + " self.transform = transforms.Compose(\n", + " [\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,)),\n", + " ]\n", + " )\n", + "\n", + " # Define PyTorch model\n", + " self.model = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(channels * width * height, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Dropout(0.1),\n", + " nn.Linear(hidden_size, self.num_classes),\n", + " )\n", + "\n", + " self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", + " self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", + "\n", + " def forward(self, x):\n", + " x = self.model(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " self.val_accuracy.update(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log(\"val_loss\", loss, prog_bar=True)\n", + " self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " logits = self(x)\n", + " loss = F.nll_loss(logits, y)\n", + " preds = torch.argmax(logits, dim=1)\n", + " self.test_accuracy.update(preds, y)\n", + "\n", + " # Calling self.log will surface up scalars for you in TensorBoard\n", + " self.log(\"test_loss\", loss, prog_bar=True)\n", + " self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return optimizer\n", + "\n", + " ####################\n", + " # DATA RELATED HOOKS\n", + " ####################\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == \"fit\" or stage is None:\n", + " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == \"test\" or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)\n", + "\n", + "\n", + "model = LitMNIST()\n", + "trainer = Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", + " max_epochs=3,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + " logger=CSVLogger(save_dir=\"logs/\"),\n", + ")\n", + "trainer.fit(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n", + "del metrics[\"step\"]\n", + "metrics.set_index(\"epoch\", inplace=True)\n", + "display(metrics.dropna(axis=1, how=\"all\").head())\n", + "sn.relplot(data=metrics, kind=\"line\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import pandas as pd\n", + "import seaborn as sn\n", + "import torch\n", + "from IPython.core.display import display\n", + "from pytorch_lightning import LightningModule, Trainer\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "from pytorch_lightning.loggers import CSVLogger\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchmetrics import Accuracy\n", + "from torchvision import transforms\n", + "from torchvision.datasets import MNIST, CIFAR10\n", + "\n", + "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", + "BATCH_SIZE = 256 if torch.cuda.is_available() else 64\n", + "\n", + "class MNISTModel(LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.l1 = torch.nn.Linear(32 * 32, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", + "\n", + " def accuracy(self, y_hat, y):\n", + " preds = torch.argmax(y_hat, dim=1)\n", + " return (preds == y).float().mean()\n", + "\n", + " def training_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "\n", + " return loss\n", + " \n", + " def validation_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"val_loss\", loss, prog_bar=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " self.log(\"val_acc\", self.accuracy(y_hat, y), prog_bar=True)\n", + " return loss\n", + " \n", + " def test_step(self, batch, batch_nb):\n", + " x, y = batch\n", + " loss = F.cross_entropy(self(x), y)\n", + " self.log(\"test_loss\", loss, prog_bar=True)\n", + " # log accuracy\n", + " y_hat = self(x)\n", + " self.log(\"test_acc\", self.accuracy(y_hat, y), prog_bar=True)\n", + "\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.02)\n", + " \n", + "\n", + "# Init our model\n", + "mnist_model = MNISTModel()\n", + "\n", + "# Init DataLoader from MNIST Dataset\n", + "# train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", + "# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", + "\n", + "# test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())\n", + "# test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)\n", + "\n", + "# Init DataLoader from CIFAR-10 Dataset\n", + "train_ds = CIFAR10(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", + "\n", + "test_ds = CIFAR10(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)\n", + "\n", + "\n", + "\n", + "# Initialize a trainer\n", + "trainer = Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", + " max_epochs=5,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + ")\n", + "\n", + "# Train the model ⚡\n", + "trainer.fit(mnist_model, train_loader)\n", + "\n", + "# Test the model\n", + "trainer.test(mnist_model, dataloaders=test_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'cifar_example'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[13], line 46\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[39mreturn\u001b[39;00m DataLoader(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcifar10_test, batch_size\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatch_size)\n\u001b[1;32m 45\u001b[0m \u001b[39m# import modules\u001b[39;00m\n\u001b[0;32m---> 46\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mcifar_example\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcifar_transformer_modules\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmy_transformer_layer\u001b[39;00m \u001b[39mimport\u001b[39;00m MyTransformerLayer\n\u001b[1;32m 47\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mcifar_example\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcifar_transformer_modules\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39membedding\u001b[39;00m \u001b[39mimport\u001b[39;00m Embedding\n\u001b[1;32m 49\u001b[0m \u001b[39m# create resnet model for cifar10 classification\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cifar_example'" + ] + } + ], + "source": [ + "import pytorch_lightning as pl\n", + "from torch.utils.data import DataLoader, random_split\n", + "from torchvision.datasets import CIFAR10\n", + "from torchvision import transforms\n", + "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", + "\n", + "from torchvision.models import resnet18\n", + "\n", + "\n", + "# create datamodule for cifar10\n", + "class CIFAR10DataModule(pl.LightningDataModule):\n", + " def __init__(self, data_dir: str = \"./\", batch_size: int = 32):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + " self.transform = transforms.Compose(\n", + " [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n", + " )\n", + "\n", + " def prepare_data(self):\n", + " # download\n", + " CIFAR10(self.data_dir, train=True, download=True)\n", + " CIFAR10(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + " # Assign train/val datasets for use in dataloaders\n", + " if stage == \"fit\" or stage is None:\n", + " cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", + " self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000])\n", + "\n", + " # Assign test dataset for use in dataloader(s)\n", + " if stage == \"test\" or stage is None:\n", + " self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=16)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=16)\n", + "\n", + " def test_dataloader(self):\n", + " return DataLoader(self.cifar10_test, batch_size=self.batch_size)\n", + " \n", + "\n", + "# # import modules\n", + "# from cifar_example.cifar_transformer_modules.my_transformer_layer import MyTransformerLayer\n", + "# from cifar_example.cifar_transformer_modules.embedding import Embedding\n", + "\n", + "# create resnet model for cifar10 classification\n", + "class CIFAR10Model(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.model = resnet18(pretrained=False, num_classes=10)\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " self.log(\"val_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log(\"test_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " # log accuracy\n", + " self.log(\"test_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=0.001)\n", + "\n", + " def accuracy(self, y_hat, y):\n", + " preds = torch.argmax(y_hat, dim=1)\n", + " return (preds == y).float().mean()\n", + " \n", + "\n", + "# # create resnet model for cifar10 classification\n", + "# class CIFAR10Model(pl.LightningModule):\n", + "# def __init__(self):\n", + "# super().__init__()\n", + "# self.embedding_size=64\n", + "# self.criterion = torch.nn.CrossEntropyLoss()\n", + "# self.embedding = Embedding(\n", + "# patch_size=8, \n", + "# in_channels=3, \n", + "# out_channels=self.embedding_size, \n", + "# return_patches=False, \n", + "# extra_token=True\n", + "# )\n", + "# self.self_attention = MyTransformerLayer(d_model=self.embedding_size, nhead=16, dropout=0.3)\n", + "# self.fc = nn.Linear(self.embedding_size, 10)\n", + " \n", + "# def forward(self, x):\n", + "# embedding = self.embedding(x)\n", + "# context = self.self_attention(embedding)\n", + "# # get the first token\n", + "# context = context[:, 0, :]\n", + "\n", + "# # context = context.mean(dim=1)\n", + "\n", + "# # get the classification\n", + "# context = self.fc(context)\n", + " \n", + "# return context\n", + "\n", + "# def training_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# # log accuracy\n", + "# self.log(\"train_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def validation_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"val_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# # log accuracy\n", + "# self.log(\"val_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def test_step(self, batch, batch_idx):\n", + "# x, y = batch\n", + "# y_hat = self(x)\n", + "# loss = F.cross_entropy(y_hat, y)\n", + "# self.log(\"test_loss\", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# # log accuracy\n", + "# self.log(\"test_acc\", self.accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)\n", + "# return loss\n", + "\n", + "# def configure_optimizers(self):\n", + "# return torch.optim.Adam(self.parameters(), lr=0.001)\n", + "\n", + "# def accuracy(self, y_hat, y):\n", + "# preds = torch.argmax(y_hat, dim=1)\n", + "# return (preds == y).float().mean()\n", + " \n", + "\n", + "# create trainer\n", + "trainer = pl.Trainer(\n", + " gpus=1 if torch.cuda.is_available() else None,\n", + " max_epochs=5,\n", + " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", + " logger=CSVLogger(save_dir=\"logs/\")\n", + ")\n", + "\n", + "# create datamodule\n", + "dm = CIFAR10DataModule()\n", + "\n", + "# create model\n", + "model = CIFAR10Model()\n", + "\n", + "# train model\n", + "trainer.fit(model, dm)\n", + "\n", + "# test model\n", + "trainer.test(model, datamodule=dm)\n", + "\n", + "# save model\n", + "trainer.save_checkpoint(\"cifar10_model.ckpt\")\n", + "\n", + "# print the metrics\n", + "print(trainer.logged_metrics)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}