10.8. 束搜索¶ 在 SageMaker Studio Lab 中打开 Notebook
在 第 10.7 节中,我们介绍了编码器—解码器架构,以及端到端训练它们的标准技术。然而,在测试时预测方面,我们只提到了*贪心*策略,即在每个时间步,我们选择具有最高预测概率的词元作为下一个词元,直到在某个时间步,我们预测出特殊的序列结束“<eos>”词元。在本节中,我们将首先形式化这种*贪心搜索*(greedy search)策略,并指出从业者倾向于遇到的一些问题。随后,我们将此策略与两种替代方案进行比较:*穷举搜索*(exhaustive search)(说明性但并不实用)和*束搜索*(beam search)(实践中的标准方法)。
让我们首先建立我们的数学符号,借鉴 第 10.7 节的约定。在任何时间步 \(t'\),解码器输出预测,表示词汇表中每个词元在序列中下一个出现的概率(\(y_{t'+1}\)的可能值),该概率以先前的词元 \(y_1, \ldots, y_{t'}\)和由编码器生成的代表输入序列的上下文变量 \(\mathbf{c}\)为条件。为了量化计算成本,用 \(\mathcal{Y}\)表示输出词汇表(包括特殊的序列结束词元“<eos>”)。我们还将输出序列的最大词元数指定为 \(T'\)。我们的目标是从所有 \(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\)个可能的输出序列中搜索一个理想的输出。请注意,这个数字略微高估了不同输出的数量,因为一旦出现“<eos>”词元,就不会有后续的词元了。然而,就我们的目的而言,这个数字大致反映了搜索空间的大小。
10.8.1. 贪心搜索¶
考虑 第 10.7 节中简单的*贪心搜索*策略。在这里,在任何时间步 \(t'\),我们都从 \(\mathcal{Y}\)中简单地选择具有最高条件概率的词元,即
一旦我们的模型输出“<eos>”(或者我们达到最大长度 \(T'\)),输出序列就完成了。
这个策略看起来可能很合理,实际上也不算太差!考虑到它的计算要求不高,你很难用更少的代价获得更多的回报。然而,如果我们暂时不考虑效率,搜索*最可能的序列*,而不是(贪心选择的)*最可能的词元*序列,可能看起来更合理。事实证明,这两者可能大相径庭。最可能的序列是最大化表达式 \(\prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c})\)的序列。在我们的机器翻译示例中,如果解码器真正恢复了底层生成过程的概率,那么这将为我们提供最可能的翻译。不幸的是,不能保证贪心搜索会给我们这个序列。
让我们用一个例子来说明。假设输出词典中有四个词元“A”、“B”、“C”和“<eos>”。在 图 10.8.1中,每个时间步下的四个数字分别表示在该时间步生成“A”、“B”、“C”和“<eos>”的条件概率。
图 10.8.1 在每个时间步,贪心搜索选择具有最高条件概率的词元。¶
在每个时间步,贪心搜索选择具有最高条件概率的词元。因此,输出序列“A”、“B”、“C”和“<eos>”将被预测(图 10.8.1)。此输出序列的条件概率是 \(0.5\times0.4\times0.4\times0.6 = 0.048\)。
接下来,让我们看 图 10.8.2中的另一个例子。与 图 10.8.1不同,在时间步 2,我们选择了词元“C”,它具有*第二*高的条件概率。
图 10.8.2 每个时间步下的四个数字表示在该时间步生成“A”、“B”、“C”和“<eos>”的条件概率。在时间步 2,选择了具有第二高条件概率的词元“C”。¶
由于时间步 3 所基于的时间步 1 和 2 的输出子序列已经从 图 10.8.1中的“A”和“B”变为 图 10.8.2中的“A”和“C”,因此时间步 3 每个词元的条件概率在 图 10.8.2中也发生了变化。假设我们在时间步 3 选择词元“B”。现在时间步 4 的条件是前三个时间步的输出子序列“A”、“C”和“B”,这与 图 10.8.1中的“A”、“B”和“C”不同。因此,在 图 10.8.2中时间步 4 生成每个词元的条件概率也与 图 10.8.1中的不同。结果,在 图 10.8.2中输出序列“A”、“C”、“B”和“<eos>”的条件概率是 \(0.5\times0.3 \times0.6\times0.6=0.054\),这比 图 10.8.1中贪心搜索的概率要大。在这个例子中,通过贪心搜索获得的输出序列“A”、“B”、“C”和“<eos>”并非最优。
10.8.2. 穷举搜索¶
如果目标是获得最可能的序列,我们可以考虑使用*穷举搜索*:枚举所有可能的输出序列及其条件概率,然后输出预测概率最高的那个。
虽然这肯定能得到我们想要的结果,但它会带来高得令人望而却步的计算成本 \(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\),该成本在序列长度上是指数级的,并且基数巨大,由词汇表大小决定。例如,当 \(|\mathcal{Y}|=10000\) 和 \(T'=10\)时(与实际应用中的数字相比,这两个数字都很小),我们将需要评估 \(10000^{10} = 10^{40}\)个序列,这已经超出了任何可预见的计算机的能力。另一方面,贪心搜索的计算成本是 \(\mathcal{O}(\left|\mathcal{Y}\right|T')\):惊人地便宜但远非最优。例如,当 \(|\mathcal{Y}|=10000\) 和 \(T'=10\)时,我们只需要评估 \(10000\times10=10^5\)个序列。
10.8.3. 束搜索¶
你可以将序列解码策略看作一个谱系,*束搜索*在贪心搜索的效率和穷举搜索的最优性之间取得了折中。束搜索最直接的版本由一个超参数来表征,即*束宽*(beam size) \(k\)。让我们来解释一下这个术语。在时间步 1,我们选择具有最高预测概率的 \(k\)个词元。它们中的每一个将分别成为 \(k\)个候选输出序列的第一个词元。在每个后续的时间步,基于前一个时间步的 \(k\)个候选输出序列,我们继续从 \(k\left|\mathcal{Y}\right|\)个可能的选择中选出 \(k\)个具有最高预测概率的候选输出序列。
图 10.8.3 束搜索的过程(束宽 \(=2\);输出序列的最大长度 \(=3\))。候选输出序列是 \(\mathit{A}\)、\(\mathit{C}\)、\(\mathit{AB}\)、\(\mathit{CE}\)、\(\mathit{ABD}\) 和 \(\mathit{CED}\)。¶
图 10.8.3用一个例子演示了束搜索的过程。假设输出词汇表只包含五个元素:\(\mathcal{Y} = \{A, B, C, D, E\}\),其中一个是“<eos>”。设束宽为 2,输出序列的最大长度为 3。在时间步 1,假设具有最高条件概率 \(P(y_1 \mid \mathbf{c})\)的词元是 \(A\)和 \(C\)。在时间步 2,对于所有 \(y_2 \in \mathcal{Y},\)我们计算
并从这十个值中选出最大的两个,比如说 \(P(A, B \mid \mathbf{c})\) 和 \(P(C, E \mid \mathbf{c})\)。然后在时间步 3,对于所有 \(y_3 \in \mathcal{Y}\),我们计算
并从这十个值中选出最大的两个,比如说 \(P(A, B, D \mid \mathbf{c})\) 和 \(P(C, E, D \mid \mathbf{c})\)。结果,我们得到六个候选输出序列:(i) \(A\);(ii) \(C\);(iii) \(A\), \(B\);(iv) \(C\), \(E\);(v) \(A\), \(B\), \(D\);和 (vi) \(C\), \(E\), \(D\)。
最后,我们基于这六个序列获得最终候选输出序列的集合(例如,丢弃包括“<eos>”及其之后的部分)。然后我们选择使以下分数最大化的输出序列
这里 \(L\)是最终候选序列的长度,而 \(\alpha\)通常设为 0.75。由于较长的序列在 (10.8.4)的求和中有更多的对数项,分母中的项 \(L^\alpha\)会对长序列进行惩罚。
束搜索的计算成本为 \(\mathcal{O}(k\left|\mathcal{Y}\right|T')\)。这个结果介于贪心搜索和穷举搜索之间。贪心搜索可以被看作是束宽设置为 1 时的束搜索特例。
10.8.4. 小结¶
序列搜索策略包括贪心搜索、穷举搜索和束搜索。束搜索通过灵活选择束宽,在准确性和计算成本之间进行权衡。