16.2. 情感分析:使用循环神经网络¶ 在 SageMaker Studio Lab 中打开 Notebook
与词相似度和类比任务一样,我们也可以将预训练的词向量应用于情感分析。由于 :numref:`sec_sentiment` 中的IMDb评论数据集不是很大,因此使用在超大规模语料库上预训练的文本表示可以减少模型的过拟合。如 :numref:`fig_nlp-map-sa-rnn` 中所示的具体示例,我们将使用预训练的GloVe模型来表示每个词元,并将这些词元表示输入到多层双向循环神经网络中,以获取文本序列表示,然后将其转换为情感分析输出 (Maas et al., 2011)。对于相同的下游应用,我们稍后将考虑一个不同的架构选择。
图 16.2.1 本节将预训练的GloVe输入到基于RNN的架构中进行情感分析。¶
import torch
from torch import nn
from d2l import torch as d2l
batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l
npx.set_np()
batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)
[22:13:34] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
16.2.1. 使用循环神经网络表示单个文本¶
在文本分类任务中(例如情感分析),可变长度的文本序列将被转换为固定长度的类别。在下面的 BiRNN
类中,虽然文本序列的每个词元都通过嵌入层(self.embedding
)获得了其单独的预训练GloVe表示,但整个序列由双向循环神经网络(self.encoder
)进行编码。更具体地说,双向LSTM在初始和最终时间步的隐藏状态(在最后一层)被连接起来作为文本序列的表示。然后,这个单一的文本表示被一个具有两个输出(“正面”和“负面”)的全连接层(self.decoder
)转换为输出类别。
class BiRNN(nn.Module):
def __init__(self, vocab_size, embed_size, num_hiddens,
num_layers, **kwargs):
super(BiRNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
# Set `bidirectional` to True to get a bidirectional RNN
self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,
bidirectional=True)
self.decoder = nn.Linear(4 * num_hiddens, 2)
def forward(self, inputs):
# The shape of `inputs` is (batch size, no. of time steps). Because
# LSTM requires its input's first dimension to be the temporal
# dimension, the input is transposed before obtaining token
# representations. The output shape is (no. of time steps, batch size,
# word vector dimension)
embeddings = self.embedding(inputs.T)
self.encoder.flatten_parameters()
# Returns hidden states of the last hidden layer at different time
# steps. The shape of `outputs` is (no. of time steps, batch size,
# 2 * no. of hidden units)
outputs, _ = self.encoder(embeddings)
# Concatenate the hidden states at the initial and final time steps as
# the input of the fully connected layer. Its shape is (batch size,
# 4 * no. of hidden units)
encoding = torch.cat((outputs[0], outputs[-1]), dim=1)
outs = self.decoder(encoding)
return outs
class BiRNN(nn.Block):
def __init__(self, vocab_size, embed_size, num_hiddens,
num_layers, **kwargs):
super(BiRNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
# Set `bidirectional` to True to get a bidirectional RNN
self.encoder = rnn.LSTM(num_hiddens, num_layers=num_layers,
bidirectional=True, input_size=embed_size)
self.decoder = nn.Dense(2)
def forward(self, inputs):
# The shape of `inputs` is (batch size, no. of time steps). Because
# LSTM requires its input's first dimension to be the temporal
# dimension, the input is transposed before obtaining token
# representations. The output shape is (no. of time steps, batch size,
# word vector dimension)
embeddings = self.embedding(inputs.T)
# Returns hidden states of the last hidden layer at different time
# steps. The shape of `outputs` is (no. of time steps, batch size,
# 2 * no. of hidden units)
outputs = self.encoder(embeddings)
# Concatenate the hidden states at the initial and final time steps as
# the input of the fully connected layer. Its shape is (batch size,
# 4 * no. of hidden units)
encoding = np.concatenate((outputs[0], outputs[-1]), axis=1)
outs = self.decoder(encoding)
return outs
让我们构建一个具有两个隐藏层的双向循环神经网络来表示用于情感分析的单个文本。
embed_size, num_hiddens, num_layers, devices = 100, 100, 2, d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)
def init_weights(module):
if type(module) == nn.Linear:
nn.init.xavier_uniform_(module.weight)
if type(module) == nn.LSTM:
for param in module._flat_weights_names:
if "weight" in param:
nn.init.xavier_uniform_(module._parameters[param])
net.apply(init_weights);
embed_size, num_hiddens, num_layers, devices = 100, 100, 2, d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)
net.initialize(init.Xavier(), ctx=devices)
[22:13:40] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[22:13:40] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
16.2.2. 加载预训练的词向量¶
下面我们加载预训练的100维(需要与 embed_size
一致)GloVe嵌入,用于词表中的词元。
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
打印词表中所有词元的向量形状。
embeds = glove_embedding[vocab.idx_to_token]
embeds.shape
torch.Size([49346, 100])
embeds = glove_embedding[vocab.idx_to_token]
embeds.shape
(49346, 100)
我们使用这些预训练的词向量来表示评论中的词元,并且在训练期间不会更新这些向量。
net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False
net.embedding.weight.set_data(embeds)
net.embedding.collect_params().setattr('grad_req', 'null')
16.2.3. 训练和评估模型¶
现在我们可以训练双向循环神经网络进行情感分析了。
lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.277, train acc 0.884, test acc 0.861
2608.4 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
lr, num_epochs = 0.01, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.305, train acc 0.867, test acc 0.852
822.5 examples/sec on [gpu(0), gpu(1)]
我们定义以下函数来使用训练好的模型 net
预测文本序列的情感。
#@save
def predict_sentiment(net, vocab, sequence):
"""Predict the sentiment of a text sequence."""
sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())
label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)
return 'positive' if label == 1 else 'negative'
#@save
def predict_sentiment(net, vocab, sequence):
"""Predict the sentiment of a text sequence."""
sequence = np.array(vocab[sequence.split()], ctx=d2l.try_gpu())
label = np.argmax(net(sequence.reshape(1, -1)), axis=1)
return 'positive' if label == 1 else 'negative'
最后,让我们用训练好的模型来预测两个简单句子的情感。
predict_sentiment(net, vocab, 'this movie is so great')
'positive'
predict_sentiment(net, vocab, 'this movie is so bad')
'negative'
predict_sentiment(net, vocab, 'this movie is so great')
'positive'
predict_sentiment(net, vocab, 'this movie is so bad')
'negative'
16.2.4. 小结¶
预训练的词向量可以表示文本序列中的单个词元。
双向循环神经网络可以表示文本序列,例如通过连接其在初始和最终时间步的隐藏状态。这个单一的文本表示可以通过全连接层转换为类别。
16.2.5. 练习¶
增加迭代次数。你能提高训练和测试的准确率吗?调整其他超参数呢?
使用更大的预训练词向量,例如300维的GloVe嵌入。这能提高分类准确率吗?
我们可以通过使用spaCy分词来提高分类准确率吗?你需要安装spaCy(
pip install spacy
)并安装英文包(python -m spacy download en
)。在代码中,首先导入spaCy(import spacy
)。然后,加载spaCy英文包(spacy_en = spacy.load('en')
)。最后,定义函数def tokenizer(text): return [tok.text for tok in spacy_en.tokenizer(text)]
并替换原来的tokenizer
函数。注意GloVe和spaCy中短语词元的不同形式。例如,短语词元“new york”在GloVe中的形式是“new-york”,而在spaCy分词后的形式是“new york”。