11.3. 注意力评分函数¶ 在 SageMaker Studio Lab 中打开 Notebook
在 第 11.2 节中,我们使用了一些基于距离的核(包括高斯核)来对查询和键之间的交互进行建模。事实证明,距离函数的计算成本比点积略高。因此,在使用 softmax 操作来确保非负的注意力权重时,大部分工作都集中在了计算更简单的*注意力评分函数*(attention scoring functions)\(a\)上,见 (11.1.3) 和 图 11.3.1。
图 11.3.1 计算注意力汇聚的输出,作为值的加权平均,其中权重是通过注意力评分函数 \(\mathit{a}\) 和 softmax 操作计算的。¶
import math
import torch
from torch import nn
from d2l import torch as d2l
import math
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
import math
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l
11.3.1. 点积注意力¶
让我们回顾一下高斯核的注意力函数(不带指数运算):
首先,注意最后一项只依赖于 \(\mathbf{q}\)。因此,它对于所有的 \((\mathbf{q}, \mathbf{k}_i)\) 对都是相同的。如 (11.1.3) 中所述,将注意力权重归一化为 \(1\) 可以确保该项完全消失。其次,注意批量规范化和层规范化(稍后讨论)都会导致激活值具有良好界定且通常为常数的范数 \(\|\mathbf{k}_i\|\)。例如,当键 \(\mathbf{k}_i\) 是由层规范化生成时,情况就是如此。因此,我们可以从 \(a\) 的定义中去掉它,而不会对结果产生重大影响。
最后,我们需要控制指数函数中参数的量级。假设查询 \(\mathbf{q} \in \mathbb{R}^d\) 和键 \(\mathbf{k}_i \in \mathbb{R}^d\) 的所有元素都是独立同分布的随机变量,均值为零,方差为单位。这两个向量的点积均值为零,方差为 \(d\)。为确保无论向量长度如何,点积的方差仍然保持为 \(1\),我们使用*缩放点积注意力*(scaled dot product attention)评分函数。也就是说,我们用 \(1/\sqrt{d}\) 来重新缩放点积。这样,我们就得到了第一个常用的注意力函数,它被用于例如 Transformers (Vaswani et al., 2017) 中:
请注意,注意力权重 \(\alpha\) 仍然需要归一化。我们可以通过使用 softmax 操作,借助 (11.1.3) 进一步简化这一点:
事实证明,所有流行的注意力机制都使用 softmax,因此在本章的其余部分,我们将仅限于此。
11.3.2. 便利函数¶
我们需要一些函数来使注意力机制能够高效部署。这包括处理可变长度字符串的工具(在自然语言处理中很常见)以及在小批量上进行高效评估的工具(批量矩阵乘法)。
11.3.2.1. 掩码 Softmax 操作¶
注意力机制最流行的应用之一是序列模型。因此,我们需要能够处理不同长度的序列。在某些情况下,这些序列可能最终出现在同一个小批量中,需要为较短的序列填充虚拟标记(参见 第 10.5 节 中的示例)。这些特殊标记不携带任何意义。例如,假设我们有以下三个句子:
Dive into Deep Learning
Learn to code <blank>
Hello world <blank> <blank>
由于我们不希望在注意力模型中出现空白,我们只需要将 \(\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\) 限制在 \(\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\) 的范围内,其中 \(l \leq n\) 是实际句子的长度。由于这是一个非常常见的问题,它有一个名字:*掩码 softmax 操作*(masked softmax operation)。
我们来实现它。实际上,这个实现稍微作弊,通过将 \(i > l\) 的 \(\mathbf{v}_i\) 的值设置为零。此外,它将注意力权重设置为一个很大的负数,例如 \(-10^{6}\),以便在实践中使其对梯度和值的贡献消失。这样做是因为线性代数内核和运算符在 GPU 上经过了大量优化,在计算上稍微浪费一些比使用带有条件(if-then-else)语句的代码要快。
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.size(1)
mask = torch.arange((maxlen), dtype=torch.float32,
device=X.device)[None, :] < valid_len[:, None]
X[~mask] = value
return X
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
if valid_lens is None:
return npx.softmax(X)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = valid_lens.repeat(shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
value=-1e6, axis=1)
return npx.softmax(X).reshape(shape)
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = jnp.arange((maxlen),
dtype=jnp.float32)[None, :] < valid_len[:, None]
return jnp.where(mask, X, value)
if valid_lens is None:
return nn.softmax(X, axis=-1)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = jnp.repeat(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.softmax(X.reshape(shape), axis=-1)
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = tf.range(start=0, limit=maxlen, dtype=tf.float32)[
None, :] < tf.cast(valid_len[:, None], dtype=tf.float32)
if len(X.shape) == 3:
return tf.where(tf.expand_dims(mask, axis=-1), X, value)
else:
return tf.where(mask, X, value)
if valid_lens is None:
return tf.nn.softmax(X, axis=-1)
else:
shape = X.shape
if len(valid_lens.shape) == 1:
valid_lens = tf.repeat(valid_lens, repeats=shape[1])
else:
valid_lens = tf.reshape(valid_lens, shape=-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(tf.reshape(X, shape=(-1, shape[-1])), valid_lens,
value=-1e6)
return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)
为了说明这个函数是如何工作的,考虑一个包含两个大小为 \(2 \times 4\) 的样本的小批量,它们的有效长度分别为 \(2\) 和 \(3\)。经过掩码 softmax 操作后,每对向量超出有效长度的值都被掩码为零。
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
tensor([[[0.4448, 0.5552, 0.0000, 0.0000],
[0.4032, 0.5968, 0.0000, 0.0000]],
[[0.2795, 0.2805, 0.4400, 0.0000],
[0.2798, 0.3092, 0.4110, 0.0000]]])
masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3]))
[22:05:24] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
array([[[0.488994 , 0.511006 , 0. , 0. ],
[0.43654838, 0.56345165, 0. , 0. ]],
[[0.28817102, 0.3519408 , 0.3598882 , 0. ],
[0.29034293, 0.25239873, 0.45725834, 0. ]]])
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)), jnp.array([2, 3]))
Array([[[0.2914798 , 0.7085202 , 0. , 0. ],
[0.5130609 , 0.48693904, 0. , 0. ]],
[[0.17453432, 0.4599773 , 0.36548832, 0. ],
[0.3574293 , 0.3150612 , 0.32750952, 0. ]]], dtype=float32)
masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
<tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[0.44247434, 0.55752563, 0. , 0. ],
[0.4283622 , 0.5716378 , 0. , 0. ]],
[[0.3462756 , 0.47043496, 0.18328945, 0. ],
[0.28437027, 0.37555325, 0.34007645, 0. ]]], dtype=float32)>
如果我们需要更精细的控制来为每个样本的两个向量指定有效长度,我们只需使用一个二维的有效长度张量。这会产生:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
[0.4109, 0.2794, 0.3097, 0.0000]],
[[0.3960, 0.6040, 0.0000, 0.0000],
[0.2557, 0.1833, 0.2420, 0.3190]]])
masked_softmax(np.random.uniform(size=(2, 2, 4)),
np.array([[1, 3], [2, 4]]))
array([[[1. , 0. , 0. , 0. ],
[0.35848376, 0.36588794, 0.2756283 , 0. ]],
[[0.54370314, 0.45629686, 0. , 0. ],
[0.19598779, 0.25580424, 0.19916737, 0.34904057]]])
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
jnp.array([[1, 3], [2, 4]]))
Array([[[1. , 0. , 0. , 0. ],
[0.31556115, 0.28214547, 0.40229338, 0. ]],
[[0.5613054 , 0.43869466, 0. , 0. ],
[0.29578257, 0.20095006, 0.2151548 , 0.28811258]]], dtype=float32)
masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
<tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[1. , 0. , 0. , 0. ],
[0.28053525, 0.3656918 , 0.35377294, 0. ]],
[[0.54060036, 0.4593996 , 0. , 0. ],
[0.25626186, 0.22631206, 0.1530483 , 0.36437777]]], dtype=float32)>
11.3.2.2. 批量矩阵乘法¶
另一个常用的操作是批量矩阵相乘。当有小批量的查询、键和值时,这非常方便。更具体地说,假设
那么批量矩阵乘法(BMM)计算逐元素的乘积:
让我们在深度学习框架中看看它的实际应用。
Q = torch.ones((2, 3, 4))
K = torch.ones((2, 4, 6))
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
Q = np.ones((2, 3, 4))
K = np.ones((2, 4, 6))
d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
Q = jnp.ones((2, 3, 4))
K = jnp.ones((2, 4, 6))
d2l.check_shape(jax.lax.batch_matmul(Q, K), (2, 3, 6))
Q = tf.ones((2, 3, 4))
K = tf.ones((2, 4, 6))
d2l.check_shape(tf.matmul(Q, K).numpy(), (2, 3, 6))
11.3.3. 缩放点积注意力¶
让我们回到 (11.3.2) 中介绍的点积注意力。一般来说,它要求查询和键具有相同的向量长度,比如 \(d\),尽管这可以通过用 \(\mathbf{q}^\top \mathbf{M} \mathbf{k}\) 替换 \(\mathbf{q}^\top \mathbf{k}\) 来轻松解决,其中 \(\mathbf{M}\) 是一个为在两个空间之间转换而适当选择的矩阵。目前,我们假设维度匹配。
在实践中,为了提高效率,我们通常考虑小批量,例如为 \(n\) 个查询和 \(m\) 个键值对计算注意力,其中查询和键的长度为 \(d\),值的长度为 \(v\)。查询 \(\mathbf Q\in\mathbb R^{n\times d}\)、键 \(\mathbf K\in\mathbb R^{m\times d}\) 和值 \(\mathbf V\in\mathbb R^{m\times v}\) 的缩放点积注意力因此可以写成:
请注意,当将其应用于小批量时,我们需要 (11.3.5) 中介绍的批量矩阵乘法。在以下缩放点积注意力的实现中,我们使用 dropout 进行模型正则化。
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.transpose(1, 2)
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
class DotProductAttention(nn.Block): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Set transpose_b=True to swap the last two dimensions of keys
scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return npx.batch_dot(self.dropout(self.attention_weights), values)
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
dropout: float
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
@nn.compact
def __call__(self, queries, keys, values, valid_lens=None,
training=False):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.swapaxes(1, 2)
scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
return dropout_layer(attention_weights)@values, attention_weights
class DotProductAttention(tf.keras.layers.Layer): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def call(self, queries, keys, values, valid_lens=None, **kwargs):
d = queries.shape[-1]
scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(
tf.cast(d, dtype=tf.float32))
self.attention_weights = masked_softmax(scores, valid_lens)
return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
为了说明 DotProductAttention
类是如何工作的,我们使用之前加性注意力的玩具示例中的相同键、值和有效长度。在我们的示例中,我们假设小批量大小为 \(2\),总共有 \(10\) 个键和值,值的维度为 \(4\)。最后,我们假设每个观测值的有效长度分别为 \(2\) 和 \(6\)。鉴于此,我们期望输出是一个 \(2 \times 1 \times 4\) 的张量,即小批量中的每个样本占一行。
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.normal(0, 1, (2, 10, 2))
values = torch.normal(0, 1, (2, 10, 4))
valid_lens = torch.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = np.random.normal(0, 1, (2, 1, 2))
keys = np.random.normal(0, 1, (2, 10, 2))
values = np.random.normal(0, 1, (2, 10, 4))
valid_lens = np.array([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
values = jax.random.normal(d2l.get_key(), (2, 10, 4))
valid_lens = jnp.array([2, 6])
attention = DotProductAttention(dropout=0.5)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
[[[ 0.75924027 -0.4776329 0.19306126 0.15036084]]
[[-0.07728005 1.1064801 -0.839485 -0.36051023]]]
queries = tf.random.normal(shape=(2, 1, 2))
keys = tf.random.normal(shape=(2, 10, 2))
values = tf.random.normal(shape=(2, 10, 4))
valid_lens = tf.constant([2, 6])
attention = DotProductAttention(dropout=0.5)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
让我们检查一下注意力权重是否确实在第二列和第六列之后消失了(因为将有效长度设置为 \(2\) 和 \(6\))。
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
11.3.4. 加性注意力¶
当查询 \(\mathbf{q}\) 和键 \(\mathbf{k}\) 是不同维度的向量时,我们可以使用矩阵通过 \(\mathbf{q}^\top \mathbf{M} \mathbf{k}\) 来解决不匹配问题,或者我们可以使用加性注意力作为评分函数。另一个好处是,正如其名,注意力是可加的。这可以带来一些微小的计算节省。给定一个查询 \(\mathbf{q} \in \mathbb{R}^q\) 和一个键 \(\mathbf{k} \in \mathbb{R}^k\),*加性注意力*(additive attention)评分函数 (Bahdanau et al., 2014) 由下式给出:
其中 \(\mathbf W_q\in\mathbb R^{h\times q}\)、\(\mathbf W_k\in\mathbb R^{h\times k}\) 和 \(\mathbf w_v\in\mathbb R^{h}\) 是可学习的参数。然后,该项被送入 softmax 以确保非负性和归一化。对 (11.3.7) 的一个等价解释是,查询和键被连接起来并送入一个具有单个隐藏层的 MLP。我们使用 \(\tanh\) 作为激活函数并禁用偏置项,实现加性注意力如下:
class AdditiveAttention(nn.Module): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.LazyLinear(num_hiddens, bias=False)
self.W_q = nn.LazyLinear(num_hiddens, bias=False)
self.w_v = nn.LazyLinear(1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return torch.bmm(self.dropout(self.attention_weights), values)
class AdditiveAttention(nn.Block): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
# Use flatten=False to only transform the last axis so that the
# shapes for the other axes are kept the same
self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.w_v = nn.Dense(1, use_bias=False, flatten=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
# no. of key-value pairs, num_hiddens). Sum them up with
# broadcasting
features = np.expand_dims(queries, axis=2) + np.expand_dims(
keys, axis=1)
features = np.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores:
# (batch_size, no. of queries, no. of key-value pairs)
scores = np.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return npx.batch_dot(self.dropout(self.attention_weights), values)
class AdditiveAttention(nn.Module): #@save
num_hiddens: int
dropout: float
def setup(self):
self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
self.w_v = nn.Dense(1, use_bias=False)
@nn.compact
def __call__(self, queries, keys, values, valid_lens, training=False):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(keys, axis=1)
features = nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return dropout_layer(attention_weights)@values, attention_weights
class AdditiveAttention(tf.keras.layers.Layer): #@save
"""Additive attention."""
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super().__init__(**kwargs)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.w_v = tf.keras.layers.Dense(1, use_bias=False)
self.dropout = tf.keras.layers.Dropout(dropout)
def call(self, queries, keys, values, valid_lens, **kwargs):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
keys, axis=1)
features = tf.nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = tf.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return tf.matmul(self.dropout(
self.attention_weights, **kwargs), values)
让我们看看 AdditiveAttention
是如何工作的。在我们的玩具示例中,我们选择大小分别为 \((2, 1, 20)\)、\((2, 10, 2)\) 和 \((2, 10, 4)\) 的查询、键和值。这与我们为 DotProductAttention
所做的选择相同,只是现在查询是 \(20\) 维的。同样,我们选择 \((2, 6)\) 作为小批量中序列的有效长度。
queries = torch.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = np.random.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = jax.random.normal(d2l.get_key(), (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
[[[ 0.8057054 -0.45312855 0.233752 0.32691044]]
[[-0.23993565 0.23599407 0.04756263 0.13463953]]]
queries = tf.random.normal(shape=(2, 1, 20))
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
在回顾注意力函数时,我们看到其行为与 DotProductAttention
在性质上非常相似。也就是说,只有在所选有效长度 \((2, 6)\) 内的项是非零的。
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(attention_weights.reshape((1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(tf.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
11.3.5. 总结¶
在本节中,我们介绍了两种关键的注意力评分函数:点积注意力和加性注意力。它们是在可变长度序列上进行聚合的有效工具。特别是,点积注意力是现代 Transformer 架构的中流砥柱。当查询和键是不同长度的向量时,我们可以改用加性注意力评分函数。优化这些层是近年来发展的关键领域之一。例如,NVIDIA的Transformer库 和 Megatron (Shoeybi et al., 2019) 都关键地依赖于注意力机制的高效变体。我们将在后续章节回顾 Transformer 时更深入地探讨这一点。