11.6. 自注意力与位置编码¶ 在 SageMaker Studio Lab 中打开 Notebook
在深度学习中,我们经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。现在有了注意力机制,我们可以想象将词元的序列输入注意力池化中,从而允许序列中的任何词元(作为查询)与所有其他词元(作为键)进行匹配。在计算下一层中某个词元的表示时,该词元(作为值)可以关注(通过其查询向量)序列中的任何其他词元(基于它们的键向量)。通过使用查询-键兼容性分数的完整集合,我们可以为每个词元计算出一个表示,即通过构建对其他词元的适当加权和。由于每个词元都在关注所有其他词元(不同于解码器步骤关注编码器步骤的情况),这类架构通常被称为*自注意力*(self-attention)模型 (Lin et al., 2017, Vaswani et al., 2017),在其他地方也被称为*内部注意力*(intra-attention)模型 (Cheng et al., 2016, Parikh et al., 2016, Paulus et al., 2017)。在本节中,我们将讨论使用自注意力进行序列编码,包括为序列顺序使用额外信息。
import math
import torch
from torch import nn
from d2l import torch as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l
11.6.1. 自注意力¶
给定一个由输入词元组成的序列 \(\mathbf{x}_1, \ldots, \mathbf{x}_n\),其中任意 \(\mathbf{x}_i \in \mathbb{R}^d\) (\(1 \leq i \leq n\)),其自注意力输出一个同样长度的序列 \(\mathbf{y}_1, \ldots, \mathbf{y}_n\),其中:
根据 (11.1.1) 中注意力池化的定义。使用多头注意力,下面的代码片段计算了一个形状为(批量大小,时间步数或词元序列长度,\(d\))的张量的自注意力。输出张量具有相同的形状。
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
[22:10:14] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
11.6.2. 比较CNN、RNN和自注意力¶
让我们比较一下将一个包含 \(n\) 个词元的序列映射到另一个等长序列的架构,其中每个输入或输出词元都由一个 \(d\) 维向量表示。具体来说,我们将考虑CNN、RNN和自注意力。我们将比较它们的计算复杂性、顺序操作和最大路径长度。请注意,顺序操作会阻碍并行计算,而任意序列位置组合之间的路径越短,学习序列中的长期依赖关系就越容易 (Hochreiter et al., 2001)。
图 11.6.1 比较CNN(省略了填充词元)、RNN和自注意力架构。¶
我们可以将任何文本序列视为一个“一维图像”。类似地,一维CNN可以处理局部特征,例如文本中的 \(n\)-grams。给定一个长度为 \(n\) 的序列,考虑一个卷积核大小为 \(k\),输入和输出通道数均为 \(d\) 的卷积层。该卷积层的计算复杂度为 \(\mathcal{O}(knd^2)\)。如 图 11.6.1 所示,CNN是分层的,因此有 \(\mathcal{O}(1)\) 个顺序操作,最大路径长度为 \(\mathcal{O}(n/k)\)。例如,在 图 11.6.1 中,\(\mathbf{x}_1\) 和 \(\mathbf{x}_5\) 位于一个核大小为3的两层CNN的感受野内。
在更新RNN的隐藏状态时,\(d \times d\) 的权重矩阵与 \(d\) 维隐藏状态的乘法计算复杂度为 \(\mathcal{O}(d^2)\)。由于序列长度为 \(n\),循环层的计算复杂度为 \(\mathcal{O}(nd^2)\)。根据 图 11.6.1,存在 \(\mathcal{O}(n)\) 个无法并行的顺序操作,最大路径长度也为 \(\mathcal{O}(n)\)。
在自注意力中,查询、键和值都是 \(n \times d\) 的矩阵。考虑在 (11.3.6) 中的缩放点积注意力,其中一个 \(n \times d\) 矩阵与一个 \(d \times n\) 矩阵相乘,然后输出的 \(n \times n\) 矩阵再与一个 \(n \times d\) 矩阵相乘。因此,自注意力的计算复杂度为 \(\mathcal{O}(n^2d)\)。正如我们从 图 11.6.1 中看到的,每个词元都通过自注意力直接连接到任何其他词元。因此,计算可以并行进行,顺序操作为 \(\mathcal{O}(1)\),最大路径长度也为 \(\mathcal{O}(1)\)。
总而言之,CNN和自注意力都享有并行计算的优势,并且自注意力的最大路径长度最短。然而,其相对于序列长度的二次计算复杂度使得自注意力在处理非常长的序列时速度慢得令人望而却步。
11.6.3. 位置编码¶
与RNN逐个递归处理序列中的词元不同,自注意力为了并行计算而放弃了顺序操作。请注意,自注意力本身不保留序列的顺序。如果模型确实需要知道输入序列的顺序,我们该怎么办?
保留词元顺序信息的主流方法是将其表示为与每个词元关联的附加输入。这些输入称为*位置编码*,它们既可以学习得到,也可以事先固定。我们现在描述一种基于正弦和余弦函数的固定位置编码的简单方案 (Vaswani et al., 2017)。
假设输入表示 \(\mathbf{X} \in \mathbb{R}^{n \times d}\) 包含了一个序列中 \(n\) 个词元的 \(d\) 维嵌入。位置编码使用一个形状相同的位置嵌入矩阵 \(\mathbf{P} \in \mathbb{R}^{n \times d}\) 输出 \(\mathbf{X} + \mathbf{P}\),该矩阵第 \(i\) 行和第 \((2j)\) 或 \((2j + 1)\) 列上的元素为:
乍一看,这种三角函数的设计看起来很奇怪。在我们解释这个设计之前,我们先在下面的 `PositionalEncoding` 类中实现它。
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
class PositionalEncoding(nn.Block): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = np.zeros((1, max_len, num_hiddens))
X = np.arange(max_len).reshape(-1, 1) / np.power(
10000, np.arange(0, num_hiddens, 2) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
return self.dropout(X)
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
num_hiddens: int
dropout: float
max_len: int = 1000
def setup(self):
# Create a long enough P
self.P = jnp.zeros((1, self.max_len, self.num_hiddens))
X = jnp.arange(self.max_len, dtype=jnp.float32).reshape(
-1, 1) / jnp.power(10000, jnp.arange(
0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens)
self.P = self.P.at[:, :, 0::2].set(jnp.sin(X))
self.P = self.P.at[:, :, 1::2].set(jnp.cos(X))
@nn.compact
def __call__(self, X, training=False):
# Flax sow API is used to capture intermediate variables
self.sow('intermediates', 'P', self.P)
X = X + self.P[:, :X.shape[1], :]
return nn.Dropout(self.dropout)(X, deterministic=not training)
class PositionalEncoding(tf.keras.layers.Layer): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
# Create a long enough P
self.P = np.zeros((1, max_len, num_hiddens))
X = np.arange(max_len, dtype=np.float32).reshape(
-1,1)/np.power(10000, np.arange(
0, num_hiddens, 2, dtype=np.float32) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def call(self, X, **kwargs):
X = X + self.P[:, :X.shape[1], :]
return self.dropout(X, **kwargs)
在位置嵌入矩阵 \(\mathbf{P}\) 中,行对应序列中的位置,列代表不同的位置编码维度。在下面的例子中,我们可以看到位置嵌入矩阵的第 \(6\) 列和第 \(7\) 列比第 \(8\) 列和第 \(9\) 列具有更高的频率。第 \(6\) 列和第 \(7\) 列(以及第 \(8\) 列和第 \(9\) 列)之间的偏移是由于正弦和余弦函数的交替。
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
params = pos_encoding.init(d2l.get_key(), jnp.zeros((1, num_steps, encoding_dim)))
X, inter_vars = pos_encoding.apply(params, jnp.zeros((1, num_steps, encoding_dim)),
mutable='intermediates')
P = inter_vars['intermediates']['P'][0] # retrieve intermediate value P
P = P[:, :X.shape[1], :]
d2l.plot(jnp.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in jnp.arange(6, 10)])
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(tf.zeros((1, num_steps, encoding_dim)), training=False)
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
11.6.3.1. 绝对位置信息¶
为了了解沿编码维度单调递减的频率如何与绝对位置信息相关联,让我们打印出 \(0, 1, \ldots, 7\) 的二进制表示。正如我们所看到的,最低位、次低位和第三低位分别在每个数字、每两个数字和每四个数字上交替。
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
for i in range(8):
print(f'{i} in binary is {i:>03b}')
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
在二进制表示中,较高位比较低位具有更低的频率。类似地,如下面的热图所示,位置编码通过使用三角函数来降低沿编码维度的频率。由于输出是浮点数,这种连续表示比二进制表示更节省空间。
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = np.expand_dims(np.expand_dims(P[0, :, :], 0), 0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = jnp.expand_dims(jnp.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
P = tf.expand_dims(tf.expand_dims(P[0, :, :], axis=0), axis=0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
11.6.3.2. 相对位置信息¶
除了捕获绝对位置信息外,上述位置编码还允许模型轻松地学习按相对位置进行关注。这是因为对于任何固定的位置偏移 \(\delta\),位置 \(i + \delta\) 的位置编码可以通过位置 \(i\) 的线性投影来表示。
这个投影可以用数学方法来解释。记 \(\omega_j = 1/10000^{2j/d}\),则 (11.6.2) 中的任意一对 \((p_{i, 2j}, p_{i, 2j+1})\) 都可以线性投影到 \((p_{i+\delta, 2j}, p_{i+\delta, 2j+1})\),对于任意固定的偏移量 \(\delta\):
其中 \(2\times 2\) 投影矩阵不依赖于任何位置索引 \(i\)。
11.6.4. 小结¶
在自注意力中,查询、键和值都来自同一个地方。CNN和自注意力都享有并行计算的优势,并且自注意力的最大路径长度最短。然而,其相对于序列长度的二次计算复杂度使得自注意力在处理非常长的序列时速度慢得令人望而却步。为了使用序列顺序信息,我们可以通过向输入表示中添加位置编码来注入绝对或相对位置信息。
11.6.5. 练习¶
假设我们设计一个深度架构,通过堆叠带有位置编码的自注意力层来表示一个序列。可能会出现什么问题?
你能设计一个可学习的位置编码方法吗?
我们能否根据自注意力中比较的查询和键之间的不同偏移量,分配不同的学习嵌入?提示:你可以参考相对位置嵌入 (Huang et al., 2018, Shaw et al., 2018)。