16.5. 自然语言推断:使用注意力¶ 在 SageMaker Studio Lab 中打开 Notebook
我们在 第 16.4 节中介绍了自然语言推断任务和SNLI数据集。 鉴于许多模型都基于复杂和深度的架构, Parikh等人(2016) 提出用注意力机制来解决自然语言推断问题,并称之为“可分解注意力模型”。 这使得模型无需循环层或卷积层,在SNLI数据集上用少得多的参数取得了当时最好的结果。 在本节中,我们将描述和实现这种基于注意力的自然语言推断方法(使用多层感知机),如 图 16.5.1所示。
图 16.5.1 本节将预训练的GloVe输入到基于注意力和多层感知机的架构中进行自然语言推断。¶
16.5.1. 模型¶
与保留前提和假设中的词元顺序相比,我们可以只将一个文本序列中的词元与另一个文本序列中的每个词元对齐,反之亦然,然后比较和聚合这些信息来预测前提和假设之间的逻辑关系。 类似于机器翻译中源句和目标句之间的词元对齐,前提和假设之间的词元对齐可以通过注意力机制很好地完成。
图 16.5.2 使用注意力机制的自然语言推断。¶
图 16.5.2 描述了使用注意力机制的自然语言推断方法。从高层次上讲,它由三个联合训练的步骤组成:关注、比较和聚合。我们将在下文中一步步地说明它们。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
16.5.1.1. 关注¶
第一步是将一个文本序列中的词元与另一个序列中的每个词元对齐。假设前提是“i do need sleep”,假设是“i am tired”。由于语义上的相似性,我们可能希望将假设中的“i”与前提中的“i”对齐,并将假设中的“tired”与前提中的“sleep”对齐。同样,我们可能希望将前提中的“i”与假设中的“i”对齐,并将前提中的“need”和“sleep”与假设中的“tired”对齐。请注意,这种对齐是使用加权平均的*软*对齐,理想情况下,与要对齐的词元相关联的权重很大。为了便于演示,图 16.5.2以*硬*的方式显示了这种对齐。
现在我们更详细地描述使用注意力机制的软对齐。假设前提和假设分别是 \(\mathbf{A} = (\mathbf{a}_1, \ldots, \mathbf{a}_m)\) 和 \(\mathbf{B} = (\mathbf{b}_1, \ldots, \mathbf{b}_n)\),它们的词元数分别为 \(m\) 和 \(n\),其中 \(\mathbf{a}_i, \mathbf{b}_j \in \mathbb{R}^{d}\) (\(i = 1, \ldots, m, j = 1, \ldots, n\)) 是一个 \(d\) 维的词向量。对于软对齐,我们计算注意力权重 \(e_{ij} \in \mathbb{R}\) 如下
其中函数 \(f\) 是一个在以下 mlp
函数中定义的多层感知机。 \(f\) 的输出维度由 mlp
的 num_hiddens
参数指定。
def mlp(num_inputs, num_hiddens, flatten):
net = []
net.append(nn.Dropout(0.2))
net.append(nn.Linear(num_inputs, num_hiddens))
net.append(nn.ReLU())
if flatten:
net.append(nn.Flatten(start_dim=1))
net.append(nn.Dropout(0.2))
net.append(nn.Linear(num_hiddens, num_hiddens))
net.append(nn.ReLU())
if flatten:
net.append(nn.Flatten(start_dim=1))
return nn.Sequential(*net)
def mlp(num_hiddens, flatten):
net = nn.Sequential()
net.add(nn.Dropout(0.2))
net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
net.add(nn.Dropout(0.2))
net.add(nn.Dense(num_hiddens, activation='relu', flatten=flatten))
return net
需要强调的是,在 (16.5.1) 中,\(f\) 分别接收输入 \(\mathbf{a}_i\) 和 \(\mathbf{b}_j\),而不是将它们成对地作为输入。这种*分解*技巧导致 \(f\) 的应用次数仅为 \(m + n\)(线性复杂度),而不是 \(mn\)(二次复杂度)。
对 (16.5.1) 中的注意力权重进行归一化,我们计算假设中所有词元向量的加权平均,以获得与前提中索引为 \(i\) 的词元软对齐的假设表示:
同样地,我们为假设中每个索引为 \(j\) 的词元计算前提词元的软对齐:
下面我们定义 Attend
类来计算假设(beta
)与输入前提 A
的软对齐,以及前提(alpha
)与输入假设 B
的软对齐。
class Attend(nn.Module):
def __init__(self, num_inputs, num_hiddens, **kwargs):
super(Attend, self).__init__(**kwargs)
self.f = mlp(num_inputs, num_hiddens, flatten=False)
def forward(self, A, B):
# Shape of `A`/`B`: (`batch_size`, no. of tokens in sequence A/B,
# `embed_size`)
# Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
# `num_hiddens`)
f_A = self.f(A)
f_B = self.f(B)
# Shape of `e`: (`batch_size`, no. of tokens in sequence A,
# no. of tokens in sequence B)
e = torch.bmm(f_A, f_B.permute(0, 2, 1))
# Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
# `embed_size`), where sequence B is softly aligned with each token
# (axis 1 of `beta`) in sequence A
beta = torch.bmm(F.softmax(e, dim=-1), B)
# Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
# `embed_size`), where sequence A is softly aligned with each token
# (axis 1 of `alpha`) in sequence B
alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)
return beta, alpha
class Attend(nn.Block):
def __init__(self, num_hiddens, **kwargs):
super(Attend, self).__init__(**kwargs)
self.f = mlp(num_hiddens=num_hiddens, flatten=False)
def forward(self, A, B):
# Shape of `A`/`B`: (b`atch_size`, no. of tokens in sequence A/B,
# `embed_size`)
# Shape of `f_A`/`f_B`: (`batch_size`, no. of tokens in sequence A/B,
# `num_hiddens`)
f_A = self.f(A)
f_B = self.f(B)
# Shape of `e`: (`batch_size`, no. of tokens in sequence A,
# no. of tokens in sequence B)
e = npx.batch_dot(f_A, f_B, transpose_b=True)
# Shape of `beta`: (`batch_size`, no. of tokens in sequence A,
# `embed_size`), where sequence B is softly aligned with each token
# (axis 1 of `beta`) in sequence A
beta = npx.batch_dot(npx.softmax(e), B)
# Shape of `alpha`: (`batch_size`, no. of tokens in sequence B,
# `embed_size`), where sequence A is softly aligned with each token
# (axis 1 of `alpha`) in sequence B
alpha = npx.batch_dot(npx.softmax(e.transpose(0, 2, 1)), A)
return beta, alpha
16.5.1.2. 比较¶
在下一步中,我们将一个序列中的词元与另一个与之软对齐的序列进行比较。请注意,在软对齐中,一个序列中的所有词元(尽管可能有不同的注意力权重)都将与另一个序列中的一个词元进行比较。为便于演示,图 16.5.2 以*硬*的方式将词元与对齐的词元配对。例如,假设关注步骤确定前提中的“need”和“sleep”都与假设中的“tired”对齐,那么将比较“tired–need sleep”这对。
在比较步骤中,我们将来自一个序列的词元和来自另一个序列的对齐词元的拼接(运算符 \([\cdot, \cdot]\))输入到函数 \(g\)(一个多层感知机)中:
在 (16.5.4) 中,\(\mathbf{v}_{A,i}\) 是前提中词元 \(i\) 与所有与之软对齐的假设词元之间的比较;而 \(\mathbf{v}_{B,j}\) 是假设中词元 \(j\) 与所有与之软对齐的前提词元之间的比较。下面的 Compare
类定义了这样的比较步骤。
class Compare(nn.Module):
def __init__(self, num_inputs, num_hiddens, **kwargs):
super(Compare, self).__init__(**kwargs)
self.g = mlp(num_inputs, num_hiddens, flatten=False)
def forward(self, A, B, beta, alpha):
V_A = self.g(torch.cat([A, beta], dim=2))
V_B = self.g(torch.cat([B, alpha], dim=2))
return V_A, V_B
class Compare(nn.Block):
def __init__(self, num_hiddens, **kwargs):
super(Compare, self).__init__(**kwargs)
self.g = mlp(num_hiddens=num_hiddens, flatten=False)
def forward(self, A, B, beta, alpha):
V_A = self.g(np.concatenate([A, beta], axis=2))
V_B = self.g(np.concatenate([B, alpha], axis=2))
return V_A, V_B
16.5.1.3. 聚合¶
有了两组比较向量 \(\mathbf{v}_{A,i}\) (\(i = 1, \ldots, m\)) 和 \(\mathbf{v}_{B,j}\) (\(j = 1, \ldots, n\)),在最后一步,我们将聚合这些信息来推断逻辑关系。我们首先对两组向量求和:
接下来,我们将两个汇总结果的拼接输入到函数 \(h\)(一个多层感知机)中,以获得逻辑关系的分类结果:
聚合步骤在下面的 Aggregate
类中定义。
class Aggregate(nn.Module):
def __init__(self, num_inputs, num_hiddens, num_outputs, **kwargs):
super(Aggregate, self).__init__(**kwargs)
self.h = mlp(num_inputs, num_hiddens, flatten=True)
self.linear = nn.Linear(num_hiddens, num_outputs)
def forward(self, V_A, V_B):
# Sum up both sets of comparison vectors
V_A = V_A.sum(dim=1)
V_B = V_B.sum(dim=1)
# Feed the concatenation of both summarization results into an MLP
Y_hat = self.linear(self.h(torch.cat([V_A, V_B], dim=1)))
return Y_hat
class Aggregate(nn.Block):
def __init__(self, num_hiddens, num_outputs, **kwargs):
super(Aggregate, self).__init__(**kwargs)
self.h = mlp(num_hiddens=num_hiddens, flatten=True)
self.h.add(nn.Dense(num_outputs))
def forward(self, V_A, V_B):
# Sum up both sets of comparison vectors
V_A = V_A.sum(axis=1)
V_B = V_B.sum(axis=1)
# Feed the concatenation of both summarization results into an MLP
Y_hat = self.h(np.concatenate([V_A, V_B], axis=1))
return Y_hat
16.5.1.4. 整合¶
通过将关注、比较和聚合步骤整合在一起,我们定义了可分解注意力模型来联合训练这三个步骤。
class DecomposableAttention(nn.Module):
def __init__(self, vocab, embed_size, num_hiddens, num_inputs_attend=100,
num_inputs_compare=200, num_inputs_agg=400, **kwargs):
super(DecomposableAttention, self).__init__(**kwargs)
self.embedding = nn.Embedding(len(vocab), embed_size)
self.attend = Attend(num_inputs_attend, num_hiddens)
self.compare = Compare(num_inputs_compare, num_hiddens)
# There are 3 possible outputs: entailment, contradiction, and neutral
self.aggregate = Aggregate(num_inputs_agg, num_hiddens, num_outputs=3)
def forward(self, X):
premises, hypotheses = X
A = self.embedding(premises)
B = self.embedding(hypotheses)
beta, alpha = self.attend(A, B)
V_A, V_B = self.compare(A, B, beta, alpha)
Y_hat = self.aggregate(V_A, V_B)
return Y_hat
class DecomposableAttention(nn.Block):
def __init__(self, vocab, embed_size, num_hiddens, **kwargs):
super(DecomposableAttention, self).__init__(**kwargs)
self.embedding = nn.Embedding(len(vocab), embed_size)
self.attend = Attend(num_hiddens)
self.compare = Compare(num_hiddens)
# There are 3 possible outputs: entailment, contradiction, and neutral
self.aggregate = Aggregate(num_hiddens, 3)
def forward(self, X):
premises, hypotheses = X
A = self.embedding(premises)
B = self.embedding(hypotheses)
beta, alpha = self.attend(A, B)
V_A, V_B = self.compare(A, B, beta, alpha)
Y_hat = self.aggregate(V_A, V_B)
return Y_hat
16.5.2. 训练和评估模型¶
现在,我们将在SNLI数据集上训练和评估定义的可分解注意力模型。我们从读取数据集开始。
16.5.2.1. 读取数据集¶
我们使用 第 16.4 节 中定义的函数下载并读取SNLI数据集。批量大小和序列长度分别设置为 \(256\) 和 \(50\)。
batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...
read 549367 examples
read 9824 examples
batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...
[21:49:40] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
read 549367 examples
read 9824 examples
16.5.2.2. 创建模型¶
我们使用预训练的100维GloVe嵌入来表示输入词元。因此,我们将 (16.5.1) 中向量 \(\mathbf{a}_i\) 和 \(\mathbf{b}_j\) 的维度预定义为100。函数 \(f\) 在 (16.5.1) 中和 \(g\) 在 (16.5.4) 中的输出维度设置为200。然后我们创建一个模型实例,初始化其参数,并加载GloVe嵌入来初始化输入词元的向量。
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds);
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
net.initialize(init.Xavier(), ctx=devices)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.set_data(embeds)
[21:49:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[21:49:49] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
Downloading ../data/glove.6B.100d.zip from http://d2l-data.s3-accelerate.amazonaws.com/glove.6B.100d.zip...
16.5.2.3. 训练和评估模型¶
与 第 13.5 节 中的 split_batch
函数不同,后者接收单个输入,如文本序列(或图像),我们定义一个 split_batch_multi_inputs
函数来接收多个输入,例如小批量中的前提和假设。
#@save
def split_batch_multi_inputs(X, y, devices):
"""Split multi-input `X` and `y` into multiple devices."""
X = list(zip(*[gluon.utils.split_and_load(
feature, devices, even_split=False) for feature in X]))
return (X, gluon.utils.split_and_load(y, devices, even_split=False))
现在我们可以在SNLI数据集上训练和评估模型。
lr, num_epochs = 0.001, 4
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.496, train acc 0.805, test acc 0.828
20383.2 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
lr, num_epochs = 0.001, 4
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
split_batch_multi_inputs)
loss 0.514, train acc 0.797, test acc 0.814
4621.6 examples/sec on [gpu(0), gpu(1)]
16.5.2.4. 使用模型¶
最后,定义预测函数以输出一对前提和假设之间的逻辑关系。
#@save
def predict_snli(net, vocab, premise, hypothesis):
"""Predict the logical relationship between the premise and hypothesis."""
net.eval()
premise = torch.tensor(vocab[premise], device=d2l.try_gpu())
hypothesis = torch.tensor(vocab[hypothesis], device=d2l.try_gpu())
label = torch.argmax(net([premise.reshape((1, -1)),
hypothesis.reshape((1, -1))]), dim=1)
return 'entailment' if label == 0 else 'contradiction' if label == 1 \
else 'neutral'
#@save
def predict_snli(net, vocab, premise, hypothesis):
"""Predict the logical relationship between the premise and hypothesis."""
premise = np.array(vocab[premise], ctx=d2l.try_gpu())
hypothesis = np.array(vocab[hypothesis], ctx=d2l.try_gpu())
label = np.argmax(net([premise.reshape((1, -1)),
hypothesis.reshape((1, -1))]), axis=1)
return 'entailment' if label == 0 else 'contradiction' if label == 1 \
else 'neutral'
我们可以使用训练好的模型来获得一对示例文本的自然语言推断结果。
predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.'])
'contradiction'
predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.'])
'contradiction'
16.5.3. 小结¶
可分解注意力模型包含三个步骤来预测前提和假设之间的逻辑关系:关注、比较和聚合。
通过注意力机制,我们可以将一个文本序列中的词元与另一个文本序列中的每个词元对齐,反之亦然。这种对齐是使用加权平均的软对齐,理想情况下,与要对齐的词元相关联的权重很大。
在计算注意力权重时,分解技巧带来了比二次复杂度更理想的线性复杂度。
我们可以使用预训练的词向量作为下游自然语言处理任务(如自然语言推断)的输入表示。