"Cell \u001b[0;32mIn[6], line 68\u001b[0m\n\u001b[1;32m 65\u001b[0m patch_size \u001b[39m=\u001b[39m \u001b[39m16\u001b[39m\n\u001b[1;32m 66\u001b[0m pos_embedding \u001b[39m=\u001b[39m PositionEmbedding(patch_size\u001b[39m=\u001b[39mpatch_size, in_channels\u001b[39m=\u001b[39m\u001b[39m3\u001b[39m, out_channels\u001b[39m=\u001b[39m\u001b[39m8\u001b[39m)\n\u001b[0;32m---> 68\u001b[0m pos_embedding(img\u001b[39m.\u001b[39;49munsqueeze(\u001b[39m0\u001b[39;49m))\n\u001b[1;32m 70\u001b[0m \u001b[39m# show the embedding and patches\u001b[39;00m\n\u001b[1;32m 71\u001b[0m embedding, patches \u001b[39m=\u001b[39m pos_embedding(img\u001b[39m.\u001b[39munsqueeze(\u001b[39m0\u001b[39m))\n",
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1195\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[6], line 58\u001b[0m, in \u001b[0;36mPositionEmbedding.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 56\u001b[0m embedding \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnorm(embedding)\n\u001b[1;32m 57\u001b[0m \u001b[39m# add the position embedding\u001b[39;00m\n\u001b[0;32m---> 58\u001b[0m pos_embedding \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mget_cosine_position_embedding(x, patch_size\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpatch_size)\n\u001b[1;32m 60\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mpos embedding: \u001b[39m\u001b[39m'\u001b[39m, pos_embedding\u001b[39m.\u001b[39mshape)\n\u001b[1;32m 61\u001b[0m \u001b[39mreturn\u001b[39;00m embedding \u001b[39m+\u001b[39m pos_embedding\n",
"Cell \u001b[0;32mIn[6], line 38\u001b[0m, in \u001b[0;36mPositionEmbedding.get_cosine_position_embedding\u001b[0;34m(self, x, patch_size)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m'\u001b[39m\u001b[39mpos_embedding inside : \u001b[39m\u001b[39m'\u001b[39m, pos_embedding\u001b[39m.\u001b[39mshape)\n\u001b[1;32m 37\u001b[0m \u001b[39m# get the sine and cosine embedding\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m pos_embedding \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mcat([torch\u001b[39m.\u001b[39;49msin(pos_embedding), torch\u001b[39m.\u001b[39;49mcos(pos_embedding)], dim\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39;49mreshape(\u001b[39m1\u001b[39;49m, no_patches, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_channels)\u001b[39m.\u001b[39mpermute(\u001b[39m0\u001b[39m, \u001b[39m2\u001b[39m, \u001b[39m1\u001b[39m)\n\u001b[1;32m 39\u001b[0m \u001b[39m# expand the position embedding\u001b[39;00m\n\u001b[1;32m 40\u001b[0m pos_embedding \u001b[39m=\u001b[39m pos_embedding\u001b[39m.\u001b[39mexpand(batch_size, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n",
"\u001b[0;31mTypeError\u001b[0m: reshape(): argument 'shape' must be tuple of SymInts, but found element of type float at pos 2"