From 2406277fcd320e3543f90bc3d62099f56cc98eae Mon Sep 17 00:00:00 2001
From: Maciej Wielgosz <maciej.wielgosz@nibio.no>
Date: Mon, 20 Feb 2023 15:19:27 +0100
Subject: [PATCH] progress towards own model (cifar experiments)

---
 cifar_example/vis.ipynb | 219 ++++++++++++++++++++--------------------
 1 file changed, 111 insertions(+), 108 deletions(-)

diff --git a/cifar_example/vis.ipynb b/cifar_example/vis.ipynb
index af4392f..a3cd485 100644
--- a/cifar_example/vis.ipynb
+++ b/cifar_example/vis.ipynb
@@ -240,141 +240,144 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "patches shape inside :  torch.Size([1, 3, 2, 2, 16, 16])\n",
-      "pos_embedding inside :  torch.Size([4, 4])\n"
+      "tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,\n",
+      "          0.0000e+00,  1.0000e+00],\n",
+      "        [ 8.4147e-01,  5.4030e-01,  6.8156e-01,  ...,  1.0000e+00,\n",
+      "          1.3335e-04,  1.0000e+00],\n",
+      "        [ 9.0930e-01, -4.1615e-01,  9.9748e-01,  ...,  1.0000e+00,\n",
+      "          2.6670e-04,  1.0000e+00],\n",
+      "        ...,\n",
+      "        [ 3.7961e-01, -9.2515e-01, -4.6453e-01,  ...,  9.9985e-01,\n",
+      "          1.2935e-02,  9.9992e-01],\n",
+      "        [-5.7338e-01, -8.1929e-01, -9.4349e-01,  ...,  9.9985e-01,\n",
+      "          1.3068e-02,  9.9991e-01],\n",
+      "        [-9.9921e-01,  3.9821e-02, -9.1628e-01,  ...,  9.9985e-01,\n",
+      "          1.3201e-02,  9.9991e-01]])\n"
      ]
     },
     {
-     "ename": "TypeError",
-     "evalue": "reshape(): argument 'shape' must be tuple of SymInts, but found element of type float at pos 2",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[6], line 68\u001b[0m\n\u001b[1;32m     65\u001b[0m patch_size \u001b[39m=\u001b[39m \u001b[39m16\u001b[39m\n\u001b[1;32m     66\u001b[0m pos_embedding \u001b[39m=\u001b[39m PositionEmbedding(patch_size\u001b[39m=\u001b[39mpatch_size, in_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m8\u001b[39m)\n\u001b[0;32m---> 68\u001b[0m pos_embedding(img\u001b[39m.\u001b[39;49munsqueeze(\u001b[39m0\u001b[39;49m))\n\u001b[1;32m     70\u001b[0m \u001b[39m# show the embedding and patches\u001b[39;00m\n\u001b[1;32m     71\u001b[0m embedding, patches \u001b[39m=\u001b[39m pos_embedding(img\u001b[39m.\u001b[39munsqueeze(\u001b[39m0\u001b[39m))\n",
-      "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
-      "Cell \u001b[0;32mIn[6], line 58\u001b[0m, in \u001b[0;36mPositionEmbedding.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     56\u001b[0m embedding \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnorm(embedding)\n\u001b[1;32m     57\u001b[0m \u001b[39m# add the position embedding\u001b[39;00m\n\u001b[0;32m---> 58\u001b[0m pos_embedding \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_cosine_position_embedding(x, patch_size\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpatch_size)\n\u001b[1;32m     60\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mpos embedding: \u001b[39m\u001b[39m'\u001b[39m, pos_embedding\u001b[39m.\u001b[39mshape)\n\u001b[1;32m     61\u001b[0m \u001b[39mreturn\u001b[39;00m embedding \u001b[39m+\u001b[39m pos_embedding\n",
-      "Cell \u001b[0;32mIn[6], line 38\u001b[0m, in \u001b[0;36mPositionEmbedding.get_cosine_position_embedding\u001b[0;34m(self, x, patch_size)\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mpos_embedding inside : \u001b[39m\u001b[39m'\u001b[39m, pos_embedding\u001b[39m.\u001b[39mshape)\n\u001b[1;32m     37\u001b[0m \u001b[39m# get the sine and cosine embedding\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m pos_embedding \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mcat([torch\u001b[39m.\u001b[39;49msin(pos_embedding), torch\u001b[39m.\u001b[39;49mcos(pos_embedding)], dim\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39;49mreshape(\u001b[39m1\u001b[39;49m, no_patches, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_channels)\u001b[39m.\u001b[39mpermute(\u001b[39m0\u001b[39m, \u001b[39m2\u001b[39m, \u001b[39m1\u001b[39m)\n\u001b[1;32m     39\u001b[0m \u001b[39m# expand the position embedding\u001b[39;00m\n\u001b[1;32m     40\u001b[0m pos_embedding \u001b[39m=\u001b[39m pos_embedding\u001b[39m.\u001b[39mexpand(batch_size, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n",
-      "\u001b[0;31mTypeError\u001b[0m: reshape(): argument 'shape' must be tuple of SymInts, but found element of type float at pos 2"
-     ]
+     "data": {
+      "text/plain": [
+       "(-0.5, 63.5, 99.5, -0.5)"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
     }
    ],
    "source": [
-    "# create position embedding\n",
-    "class PositionEmbedding(nn.Module):\n",
-    "    def __init__(self, patch_size, in_channels, out_channels):\n",
-    "        super(PositionEmbedding, self).__init__()\n",
-    "        self.patch_size = patch_size\n",
-    "        self.in_channels = in_channels\n",
-    "        self.out_channels = out_channels\n",
-    "        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)\n",
-    "        self.norm = nn.LayerNorm(out_channels)\n",
-    "        \n",
-    "    def get_patches(self, x, patch_size=8):\n",
-    "        # get the patches\n",
-    "        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)\n",
-    "\n",
-    "        return patches\n",
-    "\n",
-    "    # define cosinusoidal position embedding\n",
-    "    def get_cosine_position_embedding(self, x, patch_size=8):\n",
-    "        # get the patches\n",
-    "        patches = self.get_patches(x, patch_size=patch_size)\n",
-    "\n",
-    "        # print shape of patches\n",
-    "        print('patches shape inside : ', patches.shape)\n",
-    "\n",
-    "        # get the batch size\n",
-    "        batch_size = patches.shape[0]\n",
-    "        # get the number of patches\n",
-    "        no_patches = (32 / patch_size) ** 2\n",
-    "        # get the patch size\n",
-    "        patch_size = patch_size\n",
-    "        # get the position embedding\n",
-    "        pos_embedding = torch.arange(0, no_patches).unsqueeze(1) / (10000 ** (torch.arange(0, self.out_channels, 2) / self.out_channels))\n",
+    "import math\n",
+    "import torch\n",
     "\n",
-    "        # print self.out_channels\n",
-    "        print('self.out_channels inside : ', self.out_channels)\n",
     "\n",
-    "        # print shape\n",
-    "        print('pos_embedding inside : ', pos_embedding.shape)\n",
     "\n",
-    "        # get the sine and cosine embedding\n",
-    "        pos_embedding = torch.cat([torch.sin(pos_embedding), torch.cos(pos_embedding)], dim=1).reshape(1, no_patches, self.out_channels).permute(0, 2, 1)\n",
-    "        # expand the position embedding\n",
-    "        pos_embedding = pos_embedding.expand(batch_size, -1, -1)\n",
+    "def sinusoidal_encoding_table(n_position, d_hid, padding_idx=None):\n",
+    "    '''Build an (n_position, d_hid) sinusoidal table: sin in even columns, cos in odd columns; row padding_idx is zeroed when given.'''\n",
+    "    encoding_table = torch.zeros(n_position, d_hid)\n",
+    "    position = torch.arange(0, n_position).unsqueeze(1)\n",
+    "    div_term = torch.exp(torch.arange(0, d_hid, 2) * -(math.log(10000.0) / d_hid))\n",
+    "    encoding_table[:, 0::2] = torch.sin(position * div_term)\n",
+    "    encoding_table[:, 1::2] = torch.cos(position * div_term)[:, : d_hid // 2]  # an odd d_hid has one fewer odd column than frequencies\n",
+    "    if padding_idx is not None:\n",
+    "        encoding_table[padding_idx] = 0.\n",
+    "    return encoding_table\n",
     "\n",
-    "        print('pos_embedding inside before return : ', pos_embedding.shape)\n",
-    "        \n",
-    "        return pos_embedding\n",
-    "        \n",
-    "    def forward(self, x):\n",
-    "        # get the patches\n",
-    "        patches = self.get_patches(x, patch_size=self.patch_size)\n",
-    "        # flatten the patches\n",
-    "        patches = patches.reshape(-1, self.in_channels, self.patch_size, self.patch_size)\n",
-    "        # get the embedding\n",
-    "        embedding = self.conv(patches)\n",
-    "        # flatten the embedding\n",
-    "        embedding = embedding.reshape(-1, self.out_channels)\n",
-    "        # normalize the embedding\n",
-    "        embedding = self.norm(embedding)\n",
-    "        # add the position embedding\n",
-    "        pos_embedding = self.get_cosine_position_embedding(x, patch_size=self.patch_size)\n",
+    "seq_len = 100  # number of positions, i.e. rows of the table\n",
+    "embedding_dim = 64  # width of each encoding vector, i.e. columns of the table\n",
     "\n",
-    "        print('pos embedding: ', pos_embedding.shape)\n",
-    "        return embedding + pos_embedding\n",
+    "pos_encoding = sinusoidal_encoding_table(seq_len, embedding_dim)\n",
     "\n",
+    "print(pos_encoding)\n",
     "\n",
-    "# use the position embedding\n",
-    "patch_size = 16\n",
-    "pos_embedding = PositionEmbedding(patch_size=patch_size, in_channels=3, out_channels=8)\n",
+    "# draw the table as an image (rows = positions, columns = embedding dimensions) with the axes hidden\n",
+    "fig, ax = plt.subplots(1, 1)\n",
+    "ax.imshow(pos_encoding.detach().numpy())\n",
+    "ax.axis('off')\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  1.0000,  1.0000,  1.0000],\n",
+      "         [ 0.8415,  0.6816,  0.5332,  ...,  1.0000,  1.0000,  1.0000],\n",
+      "         [ 0.9093,  0.9975,  0.9021,  ...,  1.0000,  1.0000,  1.0000],\n",
+      "         ...,\n",
+      "         [ 0.3796, -0.4645, -0.9086,  ...,  0.9997,  0.9999,  0.9999],\n",
+      "         [-0.5734, -0.9435, -0.9914,  ...,  0.9997,  0.9998,  0.9999],\n",
+      "         [-0.9992, -0.9163, -0.7687,  ...,  0.9997,  0.9998,  0.9999]]])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(-0.5, 63.5, 99.5, -0.5)"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "def get_pos_encoding(max_len, d_emb):\n",
+    "    '''Return a (1, max_len, d_emb) sinusoidal encoding: all sin columns first, then all cos columns.'''\n",
+    "    pos = torch.arange(0, max_len).float().unsqueeze(1)\n",
     "\n",
-    "pos_embedding(img.unsqueeze(0))\n",
+    "    div = torch.exp(-torch.arange(0, d_emb, 2).float() * math.log(10000) / d_emb)  # one frequency per sin/cos column pair\n",
     "\n",
-    "# show the embedding and patches\n",
-    "embedding, patches = pos_embedding(img.unsqueeze(0))\n",
+    "    sin = torch.sin(pos * div)\n",
+    "    cos = torch.cos(pos * div)\n",
     "\n",
-    "print('patches: ', patches.shape)\n",
-    "print('embedding: ', embedding.shape)\n",
+    "    pos_encoding = torch.cat((sin, cos), dim=1)[:, :d_emb].reshape(1, max_len, d_emb)  # odd d_emb: drop the surplus cos column\n",
     "\n",
-    "# # plot the patches\n",
-    "# no = int(32 / patch_size)\n",
-    "# fig, ax = plt.subplots(no, no)\n",
-    "# for i in range(no):\n",
-    "#     for j in range(no):\n",
-    "#         ax[i, j].imshow(patches[i * no + j, :].permute(1, 2, 0))\n",
-    "#         ax[i, j].axis('off')\n",
+    "    return pos_encoding\n",
     "\n",
-    "# # plot the embeddings\n",
-    "# no = int(32 / patch_size)\n",
-    "# fig, ax = plt.subplots(no, no)\n",
-    "# for i in range(no):\n",
-    "#     for j in range(no):\n",
-    "#         ax[i, j].imshow(embedding[i * no + j, :].detach().numpy().reshape(1, -1))\n",
-    "#         ax[i, j].axis('off')\n",
-    "    \n",
+    "seq_len = 100  # number of positions, i.e. rows of the plotted table\n",
+    "embedding_dim = 64  # width of each encoding vector\n",
     "\n",
-    "# plot results of get_cosine_position_embedding\n",
-    "pos_embedding = pos_embedding.get_cosine_position_embedding(img.unsqueeze(0), patch_size=patch_size)\n",
+    "pos_encoding = get_pos_encoding(seq_len, embedding_dim)\n",
     "\n",
-    "print('pos_embedding: ', pos_embedding.shape)\n",
+    "print(pos_encoding)\n",
     "\n",
-    "# plot the embeddings\n",
-    "no = int(32 / patch_size)\n",
-    "fig, ax = plt.subplots(no, no)\n",
-    "for i in range(no):\n",
-    "    for j in range(no):\n",
-    "        ax[i, j].imshow(pos_embedding[0, :, i * no + j].detach().numpy().reshape(1, -1))\n",
-    "        ax[i, j].axis('off')\n",
-    "\n"
+    "# squeeze() removes the leading size-1 dimension so imshow gets a 2-D array; the axes are hidden\n",
+    "fig, ax = plt.subplots(1, 1)\n",
+    "ax.imshow(pos_encoding.squeeze().detach().numpy())\n",
+    "ax.axis('off')\n"
    ]
   }
  ],
-- 
GitLab