16.3. 情感分析:使用卷积神经网络¶ 在 SageMaker Studio Lab 中打开 Notebook
在 第 7 节中,我们探讨了使用二维卷积神经网络处理二维图像数据的方法,并将其应用于局部特征,如相邻像素。虽然卷积神经网络最初是为计算机视觉设计的,但它们也被广泛用于自然语言处理。简单地说,只要将任何文本序列看作是一维图像即可。这样,一维卷积神经网络就可以处理文本中的局部特征,例如\(n\)-grams。
在本节中,我们将使用 *textCNN* 模型来演示如何设计一个表示单个文本的卷积神经网络架构 (Kim, 2014)。与 图 16.2.1 中使用带有 GloVe 预训练的 RNN 架构进行情感分析相比,图 16.3.1 中唯一的不同在于架构的选择。
图 16.3.1 本节将预训练的 GloVe 输入到基于 CNN 的情感分析架构中。¶
import torch
from torch import nn
from d2l import torch as d2l
batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)
[22:21:24] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
16.3.1. 一维卷积¶
在介绍该模型之前,我们先来看看一维卷积是如何工作的。请记住,它只是基于互相关运算的二维卷积的一个特例。
图 16.3.2 一维互相关运算。阴影部分是第一个输出元素以及用于输出计算的输入和核张量元素:\(0\times1+1\times2=2\)。¶
如 图 16.3.2 所示,在一维情况下,卷积窗口从左到右滑过输入张量。在滑动过程中,特定位置卷积窗口中包含的输入子张量(例如,图 16.3.2 中的 \(0\) 和 \(1\))和核张量(例如,图 16.3.2 中的 \(1\) 和 \(2\))按元素相乘。这些乘积的和得到输出张量对应位置的单个标量值(例如,图 16.3.2 中的 \(0\times1+1\times2=2\))。
我们在下面的 corr1d
函数中实现一维互相关。给定一个输入张量 X
和一个核张量 K
,它返回输出张量 Y
。
def corr1d(X, K):
w = K.shape[0]
Y = torch.zeros((X.shape[0] - w + 1))
for i in range(Y.shape[0]):
Y[i] = (X[i: i + w] * K).sum()
return Y
def corr1d(X, K):
w = K.shape[0]
Y = np.zeros((X.shape[0] - w + 1))
for i in range(Y.shape[0]):
Y[i] = (X[i: i + w] * K).sum()
return Y
我们可以从 图 16.3.2 构造输入张量 X
和核张量 K
,以验证上述一维互相关实现的输出。
X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])
corr1d(X, K)
tensor([ 2., 5., 8., 11., 14., 17.])
X, K = np.array([0, 1, 2, 3, 4, 5, 6]), np.array([1, 2])
corr1d(X, K)
array([ 2., 5., 8., 11., 14., 17.])
对于任何具有多个通道的一维输入,卷积核需要具有相同数量的输入通道。然后,对于每个通道,对输入的一维张量和卷积核的一维张量执行互相关运算,并将所有通道的结果相加,以生成一维输出张量。图 16.3.3 显示了具有 3 个输入通道的一维互相关运算。
图 16.3.3 具有3个输入通道的一维互相关运算。阴影部分是第一个输出元素以及用于输出计算的输入和核张量元素:\(0\times1+1\times2+1\times3+2\times4+2\times(-1)+3\times(-3)=2\)。¶
我们可以为多个输入通道实现一维互相关运算,并验证 图 16.3.3 中的结果。
def corr1d_multi_in(X, K):
# First, iterate through the 0th dimension (channel dimension) of `X` and
# `K`. Then, add them together
return sum(corr1d(x, k) for x, k in zip(X, K))
X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 6, 7],
[2, 3, 4, 5, 6, 7, 8]])
K = torch.tensor([[1, 2], [3, 4], [-1, -3]])
corr1d_multi_in(X, K)
tensor([ 2., 8., 14., 20., 26., 32.])
def corr1d_multi_in(X, K):
# First, iterate through the 0th dimension (channel dimension) of `X` and
# `K`. Then, add them together
return sum(corr1d(x, k) for x, k in zip(X, K))
X = np.array([[0, 1, 2, 3, 4, 5, 6],
[1, 2, 3, 4, 5, 6, 7],
[2, 3, 4, 5, 6, 7, 8]])
K = np.array([[1, 2], [3, 4], [-1, -3]])
corr1d_multi_in(X, K)
array([ 2., 8., 14., 20., 26., 32.])
请注意,多输入通道的一维互相关等价于单输入通道的二维互相关。为了说明这一点,图 16.3.3 中多输入通道的一维互相关的一种等价形式是 图 16.3.4 中的单输入通道二维互相关,其中卷积核的高度必须与输入张量的高度相同。
图 16.3.4 具有单个输入通道的二维互相关运算。阴影部分是第一个输出元素以及用于输出计算的输入和核张量元素:\(2\times(-1)+3\times(-3)+1\times3+2\times4+0\times1+1\times2=2\)。¶
图 16.3.2 和 图 16.3.3 中的输出都只有一个通道。与 7.4.2节 中描述的具有多个输出通道的二维卷积一样,我们也可以为一维卷积指定多个输出通道。
16.3.2. 时序最大汇聚¶
类似地,我们可以使用汇聚层从序列表示中提取最高值,作为跨时间步最重要的特征。*textCNN* 中使用的*时序最大汇聚*(max-over-time pooling)的作用类似于一维全局最大汇聚 (Collobert et al., 2011)。对于一个多通道输入,其中每个通道存储不同时间步的值,每个通道的输出是该通道的最大值。请注意,时序最大汇聚允许不同通道上有不同数量的时间步。
16.3.3. textCNN 模型¶
利用一维卷积和时序最大汇聚,textCNN模型以单个预训练的词元表示为输入,然后获取并转换序列表示以用于下游应用。
对于由 \(d\) 维向量表示的 \(n\) 个词元组成的单个文本序列,输入张量的宽度、高度和通道数分别为 \(n\)、\(1\) 和 \(d\)。textCNN 模型将输入转换为输出的过程如下:
定义多个一维卷积核,并分别对输入执行卷积运算。不同宽度的卷积核可以捕获不同数量相邻词元之间的局部特征。
在所有输出通道上执行时序最大汇聚,然后将所有标量汇聚输出连接成一个向量。
使用全连接层将连接后的向量转换为输出类别。可以使用 Dropout 来减少过拟合。
图 16.3.5 textCNN的模型架构。¶
图 16.3.5 用一个具体例子说明了 textCNN 的模型架构。输入是一个包含 11 个词元的句子,每个词元都由一个 6 维向量表示。所以我们有一个 6 通道、宽度为 11 的输入。定义两个宽度分别为 2 和 4 的一维卷积核,它们分别有 4 和 5 个输出通道。它们分别产生 4 个宽度为 \(11-2+1=10\) 的输出通道和 5 个宽度为 \(11-4+1=8\) 的输出通道。尽管这 9 个通道的宽度不同,时序最大汇聚给出了一个连接后的 9 维向量,最终被转换为一个 2 维输出向量,用于二元情感预测。
16.3.3.1. 定义模型¶
我们在下面的类中实现 textCNN 模型。与 16.2节 中的双向 RNN 模型相比,除了用卷积层替换循环层外,我们还使用了两个嵌入层:一个权重可训练,另一个权重固定。
class TextCNN(nn.Module):
def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels,
**kwargs):
super(TextCNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
# The embedding layer not to be trained
self.constant_embedding = nn.Embedding(vocab_size, embed_size)
self.dropout = nn.Dropout(0.5)
self.decoder = nn.Linear(sum(num_channels), 2)
# The max-over-time pooling layer has no parameters, so this instance
# can be shared
self.pool = nn.AdaptiveAvgPool1d(1)
self.relu = nn.ReLU()
# Create multiple one-dimensional convolutional layers
self.convs = nn.ModuleList()
for c, k in zip(num_channels, kernel_sizes):
self.convs.append(nn.Conv1d(2 * embed_size, c, k))
def forward(self, inputs):
# Concatenate two embedding layer outputs with shape (batch size, no.
# of tokens, token vector dimension) along vectors
embeddings = torch.cat((
self.embedding(inputs), self.constant_embedding(inputs)), dim=2)
# Per the input format of one-dimensional convolutional layers,
# rearrange the tensor so that the second dimension stores channels
embeddings = embeddings.permute(0, 2, 1)
# For each one-dimensional convolutional layer, after max-over-time
# pooling, a tensor of shape (batch size, no. of channels, 1) is
# obtained. Remove the last dimension and concatenate along channels
encoding = torch.cat([
torch.squeeze(self.relu(self.pool(conv(embeddings))), dim=-1)
for conv in self.convs], dim=1)
outputs = self.decoder(self.dropout(encoding))
return outputs
class TextCNN(nn.Block):
def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels,
**kwargs):
super(TextCNN, self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size, embed_size)
# The embedding layer not to be trained
self.constant_embedding = nn.Embedding(vocab_size, embed_size)
self.dropout = nn.Dropout(0.5)
self.decoder = nn.Dense(2)
# The max-over-time pooling layer has no parameters, so this instance
# can be shared
self.pool = nn.GlobalMaxPool1D()
# Create multiple one-dimensional convolutional layers
self.convs = nn.Sequential()
for c, k in zip(num_channels, kernel_sizes):
self.convs.add(nn.Conv1D(c, k, activation='relu'))
def forward(self, inputs):
# Concatenate two embedding layer outputs with shape (batch size, no.
# of tokens, token vector dimension) along vectors
embeddings = np.concatenate((
self.embedding(inputs), self.constant_embedding(inputs)), axis=2)
# Per the input format of one-dimensional convolutional layers,
# rearrange the tensor so that the second dimension stores channels
embeddings = embeddings.transpose(0, 2, 1)
# For each one-dimensional convolutional layer, after max-over-time
# pooling, a tensor of shape (batch size, no. of channels, 1) is
# obtained. Remove the last dimension and concatenate along channels
encoding = np.concatenate([
np.squeeze(self.pool(conv(embeddings)), axis=-1)
for conv in self.convs], axis=1)
outputs = self.decoder(self.dropout(encoding))
return outputs
让我们创建一个 textCNN 实例。它有 3 个卷积层,核宽度分别为 3、4 和 5,都具有 100 个输出通道。
embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
devices = d2l.try_all_gpus()
net = TextCNN(len(vocab), embed_size, kernel_sizes, nums_channels)
def init_weights(module):
if type(module) in (nn.Linear, nn.Conv1d):
nn.init.xavier_uniform_(module.weight)
net.apply(init_weights);
embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
devices = d2l.try_all_gpus()
net = TextCNN(len(vocab), embed_size, kernel_sizes, nums_channels)
net.initialize(init.Xavier(), ctx=devices)
[22:21:29] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
[22:21:30] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for GPU
16.3.3.2. 加载预训练词向量¶
与 16.2节 相同,我们加载预训练的 100 维 GloVe 嵌入作为初始化的词元表示。这些词元表示(嵌入权重)将在 embedding
中进行训练,并在 constant_embedding
中保持固定。
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.data.copy_(embeds)
net.constant_embedding.weight.requires_grad = False
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.set_data(embeds)
net.constant_embedding.weight.set_data(embeds)
net.constant_embedding.collect_params().setattr('grad_req', 'null')
16.3.3.3. 训练和评估模型¶
现在我们可以训练 textCNN 模型进行情感分析了。
lr, num_epochs = 0.001, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.066, train acc 0.979, test acc 0.868
4354.2 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
lr, num_epochs = 0.001, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.089, train acc 0.969, test acc 0.864
2425.2 examples/sec on [gpu(0), gpu(1)]
下面我们使用训练好的模型来预测两个简单句子的情感。
d2l.predict_sentiment(net, vocab, 'this movie is so great')
'positive'
d2l.predict_sentiment(net, vocab, 'this movie is so bad')
'negative'
d2l.predict_sentiment(net, vocab, 'this movie is so great')
'positive'
d2l.predict_sentiment(net, vocab, 'this movie is so bad')
'negative'
16.3.4. 总结¶
一维卷积神经网络可以处理文本中的局部特征,例如\(n\)-grams。
多输入通道的一维互相关等价于单输入通道的二维互相关。
时序最大汇聚允许不同通道上有不同数量的时间步。
textCNN 模型使用一维卷积层和时序最大汇聚层,将单个词元表示转换为下游应用的输出。