11.1. 查询、键和值
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

到目前为止,我们所回顾的所有网络都严重依赖于大小固定的输入。例如,ImageNet 中的图像大小为 \(224 \times 224\) 像素,CNN 专门为此大小调整。即使在自然语言处理中,RNN 的输入大小也是定义明确且固定的。可变大小的问题可以通过一次处理一个词元,或通过专门设计的卷积核 (Kalchbrenner et al., 2014) 来解决。当输入是真正可变大小、信息内容也各不相同时,这种方法可能会导致严重问题,例如在 第 10.7 节 中文本的转换 (Sutskever et al., 2014)。特别是对于长序列,要跟踪网络已经生成或甚至看到的所有内容变得相当困难。即使是像 Yang et al. (2016) 提出的显式跟踪启发式方法,也只能提供有限的好处。

将此与数据库进行比较。在最简单的形式中,它们是键(\(k\))和值(\(v\))的集合。例如,我们的数据库 \(\mathcal{D}\) 可能包含元组 {("Zhang", "Aston"), ("Lipton", "Zachary"), ("Li", "Mu"), ("Smola", "Alex"), ("Hu", "Rachel"), ("Werness", "Brent")},其中姓是键,名是值。我们可以对 \(\mathcal{D}\) 进行操作,例如使用精确查询(\(q\))"Li",这将返回值 "Mu"。如果 ("Li", "Mu") 不在 \(\mathcal{D}\) 中,则没有有效答案。如果我们也允许近似匹配,我们将检索到 ("Lipton", "Zachary")。这个相当简单和普通的例子却教会了我们一些有用的东西:

  • 我们可以设计查询 \(q\),使其对 (\(k\), \(v\)) 对进行操作,并且无论数据库大小如何都有效。

  • 同一个查询可以根据数据库的内容得到不同的答案。

  • 用于操作大型状态空间(数据库)的“代码”可以非常简单(例如,精确匹配、近似匹配、top-\(k\))。

  • 无需压缩或简化数据库以使操作有效。

很明显,如果我们不是为了解释深度学习,就不会在这里引入一个简单的数据库。事实上,这引出了过去十年深度学习中引入的最令人兴奋的概念之一:注意力机制 (Bahdanau et al., 2014)。我们稍后会介绍其在机器翻译中的具体应用。现在,只需考虑以下内容:用 \(\mathcal{D} \stackrel{\textrm{def}}{=} \{(\mathbf{k}_1, \mathbf{v}_1), \ldots (\mathbf{k}_m, \mathbf{v}_m)\}\) 表示一个包含 \(m\)元组的数据库。此外,用 \(\mathbf{q}\) 表示一个查询。然后我们可以将对 \(\mathcal{D}\)注意力定义为

(11.1.1)\[\textrm{Attention}(\mathbf{q}, \mathcal{D}) \stackrel{\textrm{def}}{=} \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i,\]

其中 \(\alpha(\mathbf{q}, \mathbf{k}_i) \in \mathbb{R}\) (\(i = 1, \ldots, m\)) 是标量注意力权重。该操作通常被称为注意力汇聚注意力这个名字来源于该操作特别关注权重 \(\alpha\) 显著(即大)的项。因此,对 \(\mathcal{D}\) 的注意力会生成数据库中值的线性组合。实际上,这包含了上述例子作为一个特例,其中除了一个权重为零外,其余都为零。我们有一些特殊情况:

  • 权重 \(\alpha(\mathbf{q}, \mathbf{k}_i)\) 是非负的。在这种情况下,注意力机制的输出包含在由值 \(\mathbf{v}_i\) 张成的凸锥中。

  • 权重 \(\alpha(\mathbf{q}, \mathbf{k}_i)\) 形成凸组合,即 \(\sum_i \alpha(\mathbf{q}, \mathbf{k}_i) = 1\) 且对于所有 \(i\) 都有 \(\alpha(\mathbf{q}, \mathbf{k}_i) \geq 0\)。这是深度学习中最常见的设置。

  • 权重 \(\alpha(\mathbf{q}, \mathbf{k}_i)\) 中恰好有一个是 \(1\),而所有其他都是 \(0\)。这类似于传统的数据库查询。

  • 所有权重都相等,即对于所有 \(i\) 都有 \(\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{1}{m}\)。这相当于对整个数据库进行平均,在深度学习中也称为平均汇聚。

确保权重总和为 \(1\) 的常用策略是通过以下方式对其进行归一化:

(11.1.2)\[\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{{\sum_j} \alpha(\mathbf{q}, \mathbf{k}_j)}.\]

特别是,为了确保权重也是非负的,可以采用指数化。这意味着我们现在可以选择任何函数 \(a(\mathbf{q}, \mathbf{k})\),然后通过以下方式对其应用多项式模型中使用的 softmax 操作:

(11.1.3)\[\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_j \exp(a(\mathbf{q}, \mathbf{k}_j))}.\]

这个操作在所有深度学习框架中都随时可用。它是可微的,并且其梯度永远不会消失,这些都是模型所期望的属性。但请注意,上面介绍的注意力机制不是唯一的选择。例如,我们可以设计一个不可微的注意力模型,可以使用强化学习方法进行训练 (Mnih et al., 2014)。正如人们所预期的,训练这样的模型相当复杂。因此,现代注意力研究的主体遵循 图 11.1.1 中概述的框架。因此,我们着重阐述这一系列可微机制。

../_images/qkv.svg

图 11.1.1 注意力机制通过注意力汇聚计算值 \(\mathbf{v}_\mathit{i}\) 的线性组合,其中权重是根据查询 \(\mathbf{q}\) 和键 \(\mathbf{k}_\mathit{i}\) 之间的兼容性得出的。

非常值得注意的是,用于在键和值集合上执行的实际“代码”,即查询,可以非常简洁,尽管要操作的空间很大。这对于网络层来说是一个理想的属性,因为它不需要太多的参数来学习。同样方便的是,注意力可以在任意大的数据库上操作,而无需改变注意力汇聚操作的执行方式。

import torch
from d2l import torch as d2l
from mxnet import np, npx
from d2l import mxnet as d2l

npx.set_np()
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

11.1.1. 可视化

注意力机制的好处之一是它可以非常直观,特别是当权重是非负且总和为 \(1\) 时。在这种情况下,我们可能会将大权重解释为模型选择相关组件的一种方式。虽然这是一个很好的直觉,但重要的是要记住,它仅仅是直觉。无论如何,我们可能希望在应用各种不同的查询时,将其对给定键集的影响可视化。这个函数将在后面派上用场。

因此,我们定义 show_heatmaps 函数。请注意,它不接受矩阵(注意力权重)作为输入,而是接受一个具有四个轴的张量,允许一组不同的查询和权重。因此,输入 matrices 的形状是(用于显示的行数,用于显示的列数,查询数,键数)。这在稍后我们想要可视化设计 Transformer 的工作原理时会很有用。

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    d2l.use_svg_display()
    num_rows, num_cols, _, _ = matrices.shape
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    d2l.use_svg_display()
    num_rows, num_cols, _, _ = matrices.shape
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.asnumpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    d2l.use_svg_display()
    num_rows, num_cols, _, _ = matrices.shape
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix, cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);
#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    d2l.use_svg_display()
    num_rows, num_cols, _, _ = matrices.shape
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

作为快速健全性检查,我们可视化单位矩阵,它代表了只有当查询和键相同时,注意力权重才为 \(1\) 的情况。

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
../_images/output_queries-keys-values_7fe0e8_33_0.svg
attention_weights = np.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
[21:50:08] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
../_images/output_queries-keys-values_7fe0e8_36_1.svg
attention_weights = jnp.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
../_images/output_queries-keys-values_7fe0e8_39_0.svg
attention_weights = tf.reshape(tf.eye(10), (1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')
../_images/output_queries-keys-values_7fe0e8_42_0.svg

11.1.2. 小结

注意力机制使我们能够聚合来自许多(键,值)对的数据。到目前为止,我们的讨论还相当抽象,只是描述了一种汇聚数据的方式。我们还没有解释那些神秘的查询、键和值可能从何而来。一些直觉可能在这里有所帮助:例如,在回归设置中,查询可能对应于应该执行回归的位置。键是观察到过去数据的位置,而值是(回归)值本身。这就是所谓的 Nadaraya-Watson 估计器 (Nadaraya, 1964, Watson, 1964),我们将在下一节中研究它。

从设计上讲,注意力机制提供了一种可微分的控制手段,神经网络可以通过它从一个集合中选择元素,并构建相关的表示加权和。

11.1.3. 练习

  1. 假设您想重新实现经典数据库中使用的近似(键,查询)匹配,您会选择哪种注意力函数?

  2. 假设注意力函数由 \(a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i\) 给出,并且对于 \(i = 1, \ldots, m\)\(\mathbf{k}_i = \mathbf{v}_i\)。用 \(p(\mathbf{k}_i; \mathbf{q})\) 表示在使用 (11.1.3) 中的 softmax 归一化时键上的概率分布。证明 \(\nabla_{\mathbf{q}} \mathop{\textrm{Attention}}(\mathbf{q}, \mathcal{D}) = \textrm{Cov}_{p(\mathbf{k}_i; \mathbf{q})}[\mathbf{k}_i]\)

  3. 使用注意力机制设计一个可微分的搜索引擎。

  4. 回顾 Squeeze-and-Excitation Networks (Hu et al., 2018) 的设计,并从注意力机制的角度对其进行解释。