15.6. 子词嵌入
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

在英语中,“helps”“helped”和“helping”是同一单词“help”的不同屈折形式。单词“dog”和“dogs”之间的关系与“cat”和“cats”之间的关系相同,“boy”和“boyfriend”之间的关系与“girl”和“girlfriend”之间的关系相同。在法语和西班牙语等其他语言中,许多动词有40多种屈折形式,而在芬兰语中,一个名词最多可以有15种格。在语言学中,形态学研究单词的构成和单词之间的关系。然而,word2vec和GloVe都没有探讨单词的内部结构。

15.6.1. fastText模型

回想一下word2vec中单词是如何表示的。在跳元模型和连续词袋模型中,同一单词的不同屈折形式直接由不同的向量表示,它们之间没有共享参数。为了使用形态学信息,fastText模型提出了一种子词嵌入方法,其中子词是字符\(n\)-gram (Bojanowski et al., 2017)。fastText可以看作是子词级的跳元模型,其中每个中心词由其子词向量之和表示,而不是学习词级向量表示。

让我们以单词“where”为例,说明如何在fastText中为每个中心词获取子词。首先,在单词的开头和结尾添加特殊字符“<”和“>”,以便将前缀和后缀与其他子词区分开。然后,从单词中提取字符\(n\)-gram。例如,当\(n=3\)时,我们获得所有长度为3的子词:“<wh”“whe”“her”“ere”“re>”,以及特殊的子词“<where>”。

在fastText中,对于任意词\(w\),用\(\mathcal{G}_w\)表示其所有长度在3和6之间的子词及其特殊子词的并集。词表是所有单词的子词的并集。设\(\mathbf{z}_g\)是词典中子词\(g\)的向量,那么在跳元模型中,作为中心词的词\(w\)的向量\(\mathbf{v}_w\)是其子词向量的和:

(15.6.1)\[\mathbf{v}_w = \sum_{g\in\mathcal{G}_w} \mathbf{z}_g.\]

fastText的其余部分与跳元模型相同。与跳元模型相比,fastText中的词表更大,导致模型参数更多。此外,为了计算一个词的表示,必须对其所有子词向量求和,从而导致更高的计算复杂度。然而,由于具有相似结构的词之间共享来自子词的参数,罕见词甚至词表外的词在fastText中可能会获得更好的向量表示。

15.6.2. 字节对编码

在fastText中,所有提取的子词都必须具有指定的长度,例如\(3\)\(6\),因此无法预定义词表大小。为了在固定大小的词表中允许可变长度的子词,我们可以应用一种称为字节对编码(BPE)的压缩算法来提取子词 (Sennrich et al., 2015)

字节对编码对训练数据集进行统计分析,以发现单词内的公共符号,例如任意长度的连续字符。从长度为1的符号开始,字节对编码迭代地合并最频繁的连续符号对以产生新的更长的符号。请注意,为了提高效率,不考虑跨越单词边界的符号对。最后,我们可以使用这些符号作为子词来切分单词。字节对编码及其变体已用于流行的自然语言处理预训练模型的输入表示,如GPT-2 (Radford et al., 2019)和RoBERTa (Liu et al., 2019)。下面,我们将说明字节对编码的工作原理。

首先,我们将符号词表初始化为所有英文小写字符、一个特殊的词尾符号'_'和一个特殊的未知符号'[UNK]'

import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
           'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '_', '[UNK]']
import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
           'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '_', '[UNK]']

由于我们不考虑跨越单词边界的符号对,我们只需要一个字典raw_token_freqs,它将单词映射到它们在数据集中的频率(出现次数)。请注意,特殊符号'_'附加到每个单词的末尾,以便我们可以从输出符号序列(例如,“a_ tall er_ man”)中轻松地恢复单词序列(例如,“a taller man”)。由于我们从仅包含单个字符和特殊符号的词表开始合并过程,因此在每个单词(字典token_freqs的键)内的每对连续字符之间插入空格。换句话说,空格是单词内符号之间的分隔符。

raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
    token_freqs[' '.join(list(token))] = raw_token_freqs[token]
token_freqs
{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}
raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
    token_freqs[' '.join(list(token))] = raw_token_freqs[token]
token_freqs
{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}

我们定义以下get_max_freq_pair函数,该函数返回一个单词内最频繁的连续符号对,其中单词来自输入字典token_freqs的键。

def get_max_freq_pair(token_freqs):
    pairs = collections.defaultdict(int)
    for token, freq in token_freqs.items():
        symbols = token.split()
        for i in range(len(symbols) - 1):
            # Key of `pairs` is a tuple of two consecutive symbols
            pairs[symbols[i], symbols[i + 1]] += freq
    return max(pairs, key=pairs.get)  # Key of `pairs` with the max value
def get_max_freq_pair(token_freqs):
    pairs = collections.defaultdict(int)
    for token, freq in token_freqs.items():
        symbols = token.split()
        for i in range(len(symbols) - 1):
            # Key of `pairs` is a tuple of two consecutive symbols
            pairs[symbols[i], symbols[i + 1]] += freq
    return max(pairs, key=pairs.get)  # Key of `pairs` with the max value

作为一种基于连续符号频率的贪心方法,字节对编码将使用以下merge_symbols函数来合并最频繁的连续符号对以产生新的符号。

def merge_symbols(max_freq_pair, token_freqs, symbols):
    symbols.append(''.join(max_freq_pair))
    new_token_freqs = dict()
    for token, freq in token_freqs.items():
        new_token = token.replace(' '.join(max_freq_pair),
                                  ''.join(max_freq_pair))
        new_token_freqs[new_token] = token_freqs[token]
    return new_token_freqs
def merge_symbols(max_freq_pair, token_freqs, symbols):
    symbols.append(''.join(max_freq_pair))
    new_token_freqs = dict()
    for token, freq in token_freqs.items():
        new_token = token.replace(' '.join(max_freq_pair),
                                  ''.join(max_freq_pair))
        new_token_freqs[new_token] = token_freqs[token]
    return new_token_freqs

现在,我们对字典token_freqs的键迭代地执行字节对编码算法。在第一次迭代中,最频繁的连续符号对是't''a',因此字节对编码将它们合并以产生一个新符号'ta'。在第二次迭代中,字节对编码继续合并'ta''l'以产生另一个新符号'tal'

num_merges = 10
for i in range(num_merges):
    max_freq_pair = get_max_freq_pair(token_freqs)
    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)
    print(f'merge #{i + 1}:', max_freq_pair)
merge #1: ('t', 'a')
merge #2: ('ta', 'l')
merge #3: ('tal', 'l')
merge #4: ('f', 'a')
merge #5: ('fa', 's')
merge #6: ('fas', 't')
merge #7: ('e', 'r')
merge #8: ('er', '_')
merge #9: ('tall', '_')
merge #10: ('fast', '_')
num_merges = 10
for i in range(num_merges):
    max_freq_pair = get_max_freq_pair(token_freqs)
    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)
    print(f'merge #{i + 1}:', max_freq_pair)
merge #1: ('t', 'a')
merge #2: ('ta', 'l')
merge #3: ('tal', 'l')
merge #4: ('f', 'a')
merge #5: ('fa', 's')
merge #6: ('fas', 't')
merge #7: ('e', 'r')
merge #8: ('er', '_')
merge #9: ('tall', '_')
merge #10: ('fast', '_')

在字节对编码的10次迭代之后,我们可以看到列表symbols现在包含10个从其他符号迭代合并而来的符号。

print(symbols)
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall', 'fa', 'fas', 'fast', 'er', 'er_', 'tall_', 'fast_']
print(symbols)
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall', 'fa', 'fas', 'fast', 'er', 'er_', 'tall_', 'fast_']

对于字典raw_token_freqs的键中指定的相同数据集,数据集中的每个单词现在都被字节对编码算法产生的子词“fast_”、“fast”、“er_”、“tall_”和“tall”分割。例如,单词“faster_”和“taller_”分别被分割为“fast er_”和“tall er_”。

print(list(token_freqs.keys()))
['fast_', 'fast er_', 'tall_', 'tall er_']
print(list(token_freqs.keys()))
['fast_', 'fast er_', 'tall_', 'tall er_']

请注意,字节对编码的结果取决于所使用的数据集。我们还可以使用从一个数据集学习的子词来分割另一个数据集的单词。作为一种贪心方法,以下segment_BPE函数尝试从输入参数symbols中将单词分解为最长的可能子词。

def segment_BPE(tokens, symbols):
    outputs = []
    for token in tokens:
        start, end = 0, len(token)
        cur_output = []
        # Segment token with the longest possible subwords from symbols
        while start < len(token) and start < end:
            if token[start: end] in symbols:
                cur_output.append(token[start: end])
                start = end
                end = len(token)
            else:
                end -= 1
        if start < len(token):
            cur_output.append('[UNK]')
        outputs.append(' '.join(cur_output))
    return outputs
def segment_BPE(tokens, symbols):
    outputs = []
    for token in tokens:
        start, end = 0, len(token)
        cur_output = []
        # Segment token with the longest possible subwords from symbols
        while start < len(token) and start < end:
            if token[start: end] in symbols:
                cur_output.append(token[start: end])
                start = end
                end = len(token)
            else:
                end -= 1
        if start < len(token):
            cur_output.append('[UNK]')
        outputs.append(' '.join(cur_output))
    return outputs

在下文中,我们使用列表symbols中的子词(从上述数据集学习得到)来分割代表另一个数据集的tokens

tokens = ['tallest_', 'fatter_']
print(segment_BPE(tokens, symbols))
['tall e s t _', 'fa t t er_']
tokens = ['tallest_', 'fatter_']
print(segment_BPE(tokens, symbols))
['tall e s t _', 'fa t t er_']

15.6.3. 小结

  • fastText模型提出了一种子词嵌入方法。它基于word2vec中的跳元模型,将中心词表示为其子词向量的和。

  • 字节对编码对训练数据集进行统计分析,以发现单词内的常见符号。作为一种贪心方法,字节对编码迭代地合并最频繁的连续符号对。

  • 子词嵌入可以提高罕见词和词表外词的表示质量。

15.6.4. 练习

  1. 举个例子,在英语中大约有\(3\times 10^8\)个可能的\(6\)-gram。当子词太多时会出现什么问题?如何解决这个问题?提示:请参阅fastText论文 (Bojanowski et al., 2017) 3.2节的末尾。

  2. 如何设计一个基于连续词袋模型的子词嵌入模型?

  3. 为了得到一个大小为\(m\)的词表,当初始符号词表大小为\(n\)时,需要多少次合并操作?

  4. 如何扩展字节对编码的思想来提取短语?