分享

ViTPatchEmbedding理解

 520jefferson 2023-04-19 发布于中国香港

ViT(Vision Transformer)中的Patch Embedding用于将原始的2维图像转换成一系列的1维patch embeddings。

假设输入图像的维度为HxWxC,分别表示高,宽和通道数。

Patch Embeeding操作将输入图像分成N个大小为P^2C的patch,并reshape成维度为Nx(P^2C)的patches块x_{p},x_{p}\in \mathbb{R}^{N\times \left ( P^2\cdot C\right )}。其中N=\frac{HW}{P^2},表示分别在二维图像的宽和高上按P进行划分,每个patch块的维度为P^2C​​​​​​,再通过线性变换将patches投影到维度为D的空间上,也就是直接将原来大小为HxWxC的二维图像展平成N个大小为P^2C的一维向量x_{p}^{'}x_{p}^{'}\in \mathbb{R}^{N\times D}

上述的操作等价于对输入图像HxWxC执行一个内核大小为PxP,步长为P的卷积操作(虽然等价,但是ViT逻辑上并不包含任何卷积操作)。

卷积的输出计算公式为\left \lfloor \frac{n+2p-f}{s}+1 \right \rfloor,将输入图像的宽和高分别带入得到

\left \lfloor \frac{H+0-P}{P}+1 \right \rfloor = \left \lfloor \frac{H}{P}\right \rfloor\left \lfloor \frac{W+0-P}{P}+1 \right \rfloor = \left \lfloor \frac{W}{P}\right \rfloor,相乘之后就得到N,等价于将输入图像划分成N个大小为P^2C的patch块。

代码如下:

  1. class PatchEmbed(nn.Module):
  2. """
  3. Image to Patch Embedding
  4. """
  5. def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
  6. super().__init__()
  7. img_size = (img_size, img_size)
  8. patch_size = (patch_size, patch_size)
  9. num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
  10. self.img_size = img_size
  11. self.patch_size = patch_size
  12. self.num_patches = num_patches
  13. #
  14. # embed_dim表示切好的图片拉成一维向量后的特征长度
  15. #
  16. # 图像共切分为N = HW/P^2个patch块
  17. # 在实现上等同于对reshape后的patch序列进行一个PxP且stride为P的卷积操作
  18. # output = {[(n+2p-f)/s + 1]向下取整}^2
  19. # 即output = {[(n-P)/P + 1]向下取整}^2 = (n/P)^2
  20. #
  21. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  22. def forward(self, x):
  23. B, C, H, W = x.shape
  24. assert H == self.img_size[0] and W == self.img_size[1], \
  25. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  26. x = self.proj(x).flatten(2).transpose(1, 2)
  27. return x # x.shape is [8, 196, 768]

其中卷积操作self.proj之后接着一步flatten(2)展平操作,表示将patch投影到维度为D=P^2的空间上。最后进行转置操作,表示输入图像经过转换后生成长度为196(14*14,表示共有196个patches),维度为768(3*16*16)的特征向量。

参考:

"未来"的经典之作ViT:transformer is all you need! - 知乎

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多