11.8. 用于视觉的Transformer
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

Transformer架构最初是为序列到序列的学习而提出的,重点是机器翻译。随后,Transformer成为各种自然语言处理任务的首选模型 (Brown et al., 2020, Devlin et al., 2018, Radford et al., 2018, Radford et al., 2019, Raffel et al., 2020)。然而,在计算机视觉领域,主流架构仍然是CNN (第 8 节)。很自然地,研究人员开始思考,是否有可能通过将Transformer模型应用于图像数据来做得更好。这个问题激发了计算机视觉社区的巨大兴趣。最近,Ramachandran et al. (2019) 提出了一种用自注意力代替卷积的方案。然而,它在注意力中使用了专门的模式,这使得在硬件加速器上扩展模型变得困难。然后,Cordonnier et al. (2020) 从理论上证明了自注意力可以学习到与卷积类似的行为。在经验上,\(2 \times 2\) 的图像块被用作输入,但小的图像块尺寸使得该模型仅适用于低分辨率的图像数据。

在没有对图像块大小进行特定限制的情况下,视觉Transformer(ViT)从图像中提取图像块,并将它们输入Transformer编码器以获得全局表示,最终将其转换为用于分类 (Dosovitskiy et al., 2021)。值得注意的是,Transformer比CNN表现出更好的可扩展性:在更大的数据集上训练更大的模型时,视觉Transformer的性能明显优于ResNet。与自然语言处理中的网络架构设计领域类似,Transformer也成为计算机视觉领域的游戏规则改变者。

import torch
from torch import nn
from d2l import torch as d2l
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

11.8.1. 模型

图 11.8.1 描述了视觉Transformer的模型架构。该架构由一个将图像分块的主干、一个基于多层Transformer编码器的主体和一个将全局表示转换为输出标签的头部组成。

../_images/vit.svg

图 11.8.1 视觉Transformer架构。在这个例子中,一张图像被分割成九个图像块。一个特殊的“<cls>”词元和九个展平的图像块通过图像块嵌入和 \(\mathit{n}\) 个Transformer编码器块分别转换为十个表示。 “<cls>”表示被进一步转换为输出标签。

考虑一个高度为 \(h\)、宽度为 \(w\)\(c\) 个通道的输入图像。将图像块的高度和宽度都指定为 \(p\),图像被分割成一个由 \(m = hw/p^2\) 个图像块组成的序列,其中每个图像块被展平为一个长度为 \(cp^2\) 的向量。这样,图像块可以像文本序列中的词元一样被Transformer编码器处理。一个特殊的“<cls>”(类)词元和 \(m\) 个展平的图像块被线性投影到一个由 \(m+1\) 个向量组成的序列中,并与可学习的位置嵌入相加。多层Transformer编码器将 \(m+1\) 个输入向量转换为相同数量、相同长度的输出向量表示。它的工作方式与 图 11.7.1 中的原始Transformer编码器完全相同,只是在归一化的位置上有所不同。由于“<cls>”词元通过自注意力机制关注所有的图像块(参见 图 11.6.1),它从Transformer编码器输出的表示将被进一步转换为输出标签。

11.8.2. 图像块嵌入

为了实现一个视觉Transformer,我们从 图 11.8.1 中的图像块嵌入开始。将图像分割成图像块并将这些展平的图像块进行线性投影,可以简化为单个卷积操作,其中核大小和步幅大小都设置为图像块大小。

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        return self.conv(X).flatten(2).transpose(1, 2)
class PatchEmbedding(nn.Module):
    img_size: int = 96
    patch_size: int = 16
    num_hiddens: int = 512

    def setup(self):
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(self.img_size), _make_tuple(self.patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.Conv(self.num_hiddens, kernel_size=patch_size,
                            strides=patch_size, padding='SAME')

    def __call__(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        X = self.conv(X)
        return X.reshape((X.shape[0], -1, X.shape[3]))

在下面的示例中,以高度和宽度为 img_size 的图像作为输入,图像块嵌入输出 (img_size//patch_size)**2 个图像块,这些图像块被线性投影到长度为 num_hiddens 的向量。

img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
d2l.check_shape(patch_emb(X),
                (batch_size, (img_size//patch_size)**2, num_hiddens))
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = jnp.zeros((batch_size, img_size, img_size, 3))
output, _ = patch_emb.init_with_output(d2l.get_key(), X)
d2l.check_shape(output, (batch_size, (img_size//patch_size)**2, num_hiddens))

11.8.3. 视觉Transformer编码器

视觉Transformer编码器的MLP与原始Transformer编码器的逐位 FFN 略有不同(参见 11.7.2 节)。首先,这里的激活函数使用高斯误差线性单元(GELU),可以看作是 ReLU 的更平滑版本 (Hendrycks and Gimpel, 2016)。其次,在MLP中每个全连接层的输出上应用 dropout 进行正则化。

class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))
class ViTMLP(nn.Module):
    mlp_num_hiddens: int
    mlp_num_outputs: int
    dropout: float = 0.5

    @nn.compact
    def __call__(self, x, training=False):
        x = nn.Dense(self.mlp_num_hiddens)(x)
        x = nn.gelu(x)
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        x = nn.Dense(self.mlp_num_outputs)(x)
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        return x

视觉Transformer编码器块的实现遵循了 图 11.8.1 中的预归一化设计,其中归一化在多头注意力或MLP之前应用。与后归一化(图 11.7.1 中的“add & norm”),即在残差连接之后放置归一化相比,预归一化可以为Transformer带来更有效或高效的训练 (Baevski and Auli, 2018, Wang et al., 2019, Xiong et al., 2020)

class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = X + self.attention(*([self.ln1(X)] * 3), valid_lens)
        return X + self.mlp(self.ln2(X))
class ViTBlock(nn.Module):
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads,
                                                self.dropout, self.use_bias)
        self.mlp = ViTMLP(self.mlp_num_hiddens, self.num_hiddens, self.dropout)

    @nn.compact
    def __call__(self, X, valid_lens=None, training=False):
        X = X + self.attention(*([nn.LayerNorm()(X)] * 3),
                               valid_lens, training=training)[0]
        return X + self.mlp(nn.LayerNorm()(X), training=training)

11.7.4节 中一样,视觉Transformer编码器块不会改变其输入形状。

X = torch.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 24, 48, 8, 0.5)
encoder_blk.eval()
d2l.check_shape(encoder_blk(X), X.shape)
X = jnp.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 48, 8, 0.5)
d2l.check_shape(encoder_blk.init_with_output(d2l.get_key(), X)[0], X.shape)

11.8.4. 整合

下面视觉Transformer的前向传播过程非常直接。首先,输入图像被送入一个 PatchEmbedding 实例,其输出与“<cls>”词元嵌入连接起来。它们与可学习的位置嵌入相加,然后进行dropout。接着,输出被送入堆叠了 num_blksViTBlock 类实例的Transformer编码器。最后, “<cls>”词元的表示被网络头部投影。

class ViT(d2l.Classifier):
    """Vision Transformer."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(
            img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", ViTBlock(
                num_hiddens, num_hiddens, mlp_num_hiddens,
                num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        X = self.dropout(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X)
        return self.head(X[:, 0])
class ViT(d2l.Classifier):
    """Vision Transformer."""
    img_size: int
    patch_size: int
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    num_blks: int
    emb_dropout: float
    blk_dropout: float
    lr: float = 0.1
    use_bias: bool = False
    num_classes: int = 10
    training: bool = False

    def setup(self):
        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size,
                                              self.num_hiddens)
        self.cls_token = self.param('cls_token', nn.initializers.zeros,
                                    (1, 1, self.num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = self.param('pos_embed', nn.initializers.normal(),
                                        (1, num_steps, self.num_hiddens))
        self.blks = [ViTBlock(self.num_hiddens, self.mlp_num_hiddens,
                              self.num_heads, self.blk_dropout, self.use_bias)
                    for _ in range(self.num_blks)]
        self.head = nn.Sequential([nn.LayerNorm(), nn.Dense(self.num_classes)])

    @nn.compact
    def __call__(self, X):
        X = self.patch_embedding(X)
        X = jnp.concatenate((jnp.tile(self.cls_token, (X.shape[0], 1, 1)), X), 1)
        X = nn.Dropout(emb_dropout, deterministic=not self.training)(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X, training=self.training)
        return self.head(X[:, 0])

11.8.5. 训练

在Fashion-MNIST数据集上训练视觉Transformer就像在 第 8 节 中训练CNN一样。

img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)
../_images/output_vision-transformer_a17f42_66_0.svg
img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)
../_images/output_vision-transformer_a17f42_69_0.svg

11.8.6. 总结与讨论

你可能已经注意到,对于像Fashion-MNIST这样的小型数据集,我们实现的视觉Transformer的性能并不优于 第 8.6 节 中的ResNet。即使在ImageNet数据集(120万张图像)上也可以观察到类似的现象。这是因为Transformer缺乏卷积中的那些有用原则,如平移不变性和局部性(第 7.1 节)。然而,当在更大的数据集(例如3亿张图像)上训练更大的模型时,情况就变了,视觉Transformer在图像分类方面以很大的优势超过了ResNet,展示了Transformer在可扩展性方面的内在优势 (Dosovitskiy et al., 2021)。视觉Transformer的引入改变了图像数据建模的网络设计格局。它们很快被证明在ImageNet数据集上,通过DeiT的数据高效训练策略是有效的 (Touvron et al., 2021)。然而,自注意力的二次复杂度(第 11.6 节)使得Transformer架构不太适合高分辨率图像。为了成为计算机视觉中通用的骨干网络,Swin Transformer通过移位窗口解决了关于图像大小的二次计算复杂度问题(第 11.6.2 节),并恢复了类似卷积的先验知识,将Transformer的适用性扩展到图像分类之外的一系列计算机视觉任务,并取得了最先进的结果 (Liu et al., 2021)

11.8.7. 练习

  1. img_size 的值如何影响训练时间?

  2. 如果不将“<cls>”词元表示投影到输出,你会如何投影平均的图像块表示?实现这个改变,看看它如何影响准确性。

  3. 你能修改超参数来提高视觉Transformer的准确性吗?