10.1. 长短期记忆网络(LSTM)¶ 在 SageMaker Studio Lab 中打开 Notebook
在使用反向传播 (Elman, 1990) 训练了第一个Elman风格的循环神经网络之后不久,学习长期依赖(由于梯度消失和爆炸)的问题就变得很突出了,Bengio和Hochreiter讨论了这个问题 (Bengio et al., 1994, Hochreiter et al., 2001)。Hochreiter早在1991年就在他的硕士论文中阐明了这个问题,但由于论文是用德语写的,其结果并不广为人知。虽然梯度裁剪有助于解决梯度爆炸问题,但处理梯度消失问题似乎需要更精细的解决方案。最早、也是最成功的解决梯度消失的技术之一是长短期记忆(LSTM)模型,它由 Hochreiter和Schmidhuber(1997) 提出。LSTM与标准的循环神经网络相似,但这里每个普通的循环节点都被一个*记忆单元*所取代。每个记忆单元都包含一个*内部状态*,即一个带有固定权重为1的自连接循环边的节点,确保梯度可以跨越多个时间步而不发生消失或爆炸。
“长短期记忆”这个术语源于以下直觉。简单的循环神经网络具有权值形式的*长期记忆*。权重在训练过程中变化缓慢,编码了关于数据的一般知识。它们还具有短暂激活形式的*短期记忆*,从每个节点传递到后续节点。LSTM模型通过记忆单元引入了一种中间类型的存储。记忆单元是一个复合单元,由特定连接模式下的简单节点构建而成,并新颖地加入了乘法节点。
import torch
from torch import nn
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
10.1.1. 门控记忆单元¶
每个记忆单元都配备了一个*内部状态*和多个乘法门,这些门决定了 (i) 一个给定的输入是否应该影响内部状态(*输入门*),(ii) 内部状态是否应该被刷新为\(0\)(*遗忘门*),以及 (iii) 一个给定神经元的内部状态是否被允许影响单元的输出(*输出门*)。
10.1.1.2. 输入门、遗忘门和输出门¶
输入到LSTM门的数据是当前时间步的输入和前一个时间步的隐藏状态,如 图 10.1.1 所示。三个具有sigmoid激活函数的全连接层计算输入门、遗忘门和输出门的值。由于sigmoid激活的结果,这三个门的所有值都在 \((0, 1)\) 的范围内。此外,我们还需要一个*输入节点*,通常用*tanh*激活函数计算。直观地说,*输入门*决定了输入节点的值有多少应该被添加到当前记忆单元的内部状态中。*遗忘门*决定是保留记忆单元的当前值还是将其刷新。而*输出门*决定了记忆单元是否应该影响当前时间步的输出。
图 10.1.1 在LSTM模型中计算输入门、遗忘门和输出门。¶
数学上,假设有 \(h\) 个隐藏单元,批量大小为 \(n\),输入数量为 \(d\)。因此,输入是 \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\),前一个时间步的隐藏状态是 \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\)。相应地,时间步 \(t\) 的门定义如下:输入门为 \(\mathbf{I}_t \in \mathbb{R}^{n \times h}\),遗忘门为 \(\mathbf{F}_t \in \mathbb{R}^{n \times h}\),输出门为 \(\mathbf{O}_t \in \mathbb{R}^{n \times h}\)。它们的计算方式如下:
其中 \(\mathbf{W}_{\textrm{xi}}, \mathbf{W}_{\textrm{xf}}, \mathbf{W}_{\textrm{xo}} \in \mathbb{R}^{d \times h}\) 和 \(\mathbf{W}_{\textrm{hi}}, \mathbf{W}_{\textrm{hf}}, \mathbf{W}_{\textrm{ho}} \in \mathbb{R}^{h \times h}\) 是权重参数,\(\mathbf{b}_\textrm{i}, \mathbf{b}_\textrm{f}, \mathbf{b}_\textrm{o} \in \mathbb{R}^{1 \times h}\) 是偏置参数。请注意,在求和期间会触发广播(参见 2.1.4节)。我们使用sigmoid函数(如 5.1节 中介绍的)将输入值映射到区间 \((0, 1)\)。
10.1.1.3. 输入节点¶
接下来我们设计记忆单元。由于我们还没有指定各种门的作用,我们首先引入*输入节点* \(\tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h}\)。它的计算与上述三个门类似,但使用 \(\tanh\) 函数作为激活函数,其值域为 \((-1, 1)\)。这导致在时间步 \(t\) 有以下方程:
其中 \(\mathbf{W}_{\textrm{xc}} \in \mathbb{R}^{d \times h}\) 和 \(\mathbf{W}_{\textrm{hc}} \in \mathbb{R}^{h \times h}\) 是权重参数,\(\mathbf{b}_\textrm{c} \in \mathbb{R}^{1 \times h}\) 是偏置参数。
输入节点的简要图示见 图 10.1.2。
图 10.1.2 在LSTM模型中计算输入节点。¶
10.1.1.4. 记忆单元内部状态¶
在LSTM中,输入门 \(\mathbf{I}_t\) 控制我们通过 \(\tilde{\mathbf{C}}_t\) 考虑新数据的程度,而遗忘门 \(\mathbf{F}_t\) 决定我们保留旧记忆单元内部状态 \(\mathbf{C}_{t-1} \in \mathbb{R}^{n \times h}\) 的程度。使用哈达玛(按元素)积运算符 \(\odot\),我们得到以下更新方程:
如果遗忘门始终为1,输入门始终为0,则记忆单元的内部状态 \(\mathbf{C}_{t-1}\) 将永远保持不变,不加改变地传递到每个后续时间步。然而,输入门和遗忘门为模型提供了灵活性,使其能够学习何时保持该值不变,以及何时响应后续输入对其进行扰动。在实践中,这种设计减轻了梯度消失问题,从而使模型更容易训练,尤其是在面对具有长序列长度的数据集时。
我们因此得到了 图 10.1.3 中的流程图。
图 10.1.3 在LSTM模型中计算记忆单元内部状态。¶
10.1.2. 从零开始实现¶
现在让我们从零开始实现一个LSTM。与 9.5节 中的实验一样,我们首先加载《时间机器》数据集。
10.1.2.1. 初始化模型参数¶
接下来,我们需要定义和初始化模型参数。和之前一样,超参数 num_hiddens
决定了隐藏单元的数量。我们按照标准差为0.01的高斯分布来初始化权重,并将偏置设为0。
class LSTMScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
nn.Parameter(torch.zeros(num_hiddens)))
self.W_xi, self.W_hi, self.b_i = triple() # Input gate
self.W_xf, self.W_hf, self.b_f = triple() # Forget gate
self.W_xo, self.W_ho, self.b_o = triple() # Output gate
self.W_xc, self.W_hc, self.b_c = triple() # Input node
实际的模型如上所述定义,由三个门和一个输入节点组成。请注意,只有隐藏状态被传递到输出层。
@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
if H_C is None:
# Initial state with shape: (batch_size, num_hiddens)
H = torch.zeros((inputs.shape[1], self.num_hiddens),
device=inputs.device)
C = torch.zeros((inputs.shape[1], self.num_hiddens),
device=inputs.device)
else:
H, C = H_C
outputs = []
for X in inputs:
I = torch.sigmoid(torch.matmul(X, self.W_xi) +
torch.matmul(H, self.W_hi) + self.b_i)
F = torch.sigmoid(torch.matmul(X, self.W_xf) +
torch.matmul(H, self.W_hf) + self.b_f)
O = torch.sigmoid(torch.matmul(X, self.W_xo) +
torch.matmul(H, self.W_ho) + self.b_o)
C_tilde = torch.tanh(torch.matmul(X, self.W_xc) +
torch.matmul(H, self.W_hc) + self.b_c)
C = F * C + I * C_tilde
H = O * torch.tanh(C)
outputs.append(H)
return outputs, (H, C)
class LSTMScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: np.random.randn(*shape) * sigma
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
np.zeros(num_hiddens))
self.W_xi, self.W_hi, self.b_i = triple() # Input gate
self.W_xf, self.W_hf, self.b_f = triple() # Forget gate
self.W_xo, self.W_ho, self.b_o = triple() # Output gate
self.W_xc, self.W_hc, self.b_c = triple() # Input node
实际的模型如上所述定义,由三个门和一个输入节点组成。请注意,只有隐藏状态被传递到输出层。
@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
if H_C is None:
# Initial state with shape: (batch_size, num_hiddens)
H = np.zeros((inputs.shape[1], self.num_hiddens),
ctx=inputs.ctx)
C = np.zeros((inputs.shape[1], self.num_hiddens),
ctx=inputs.ctx)
else:
H, C = H_C
outputs = []
for X in inputs:
I = npx.sigmoid(np.dot(X, self.W_xi) +
np.dot(H, self.W_hi) + self.b_i)
F = npx.sigmoid(np.dot(X, self.W_xf) +
np.dot(H, self.W_hf) + self.b_f)
O = npx.sigmoid(np.dot(X, self.W_xo) +
np.dot(H, self.W_ho) + self.b_o)
C_tilde = np.tanh(np.dot(X, self.W_xc) +
np.dot(H, self.W_hc) + self.b_c)
C = F * C + I * C_tilde
H = O * np.tanh(C)
outputs.append(H)
return outputs, (H, C)
class LSTMScratch(d2l.Module):
num_inputs: int
num_hiddens: int
sigma: float = 0.01
def setup(self):
init_weight = lambda name, shape: self.param(name,
nn.initializers.normal(self.sigma),
shape)
triple = lambda name : (
init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens)))
self.W_xi, self.W_hi, self.b_i = triple('i') # Input gate
self.W_xf, self.W_hf, self.b_f = triple('f') # Forget gate
self.W_xo, self.W_ho, self.b_o = triple('o') # Output gate
self.W_xc, self.W_hc, self.b_c = triple('c') # Input node
实际的模型如上所述定义,由三个门和一个输入节点组成。请注意,只有隐藏状态被传递到输出层。forward
方法中的长for循环会导致首次运行时JIT编译时间非常长。作为解决方案,JAX使用jax.lax.scan
实用转换来实现相同的行为,而不是使用for循环来在每个时间步更新状态。它接收一个名为carry
的初始状态和一个inputs
数组,该数组在其主轴上被扫描。scan
转换最终会按预期返回最终状态和堆叠的输出。
@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
# Use lax.scan primitive instead of looping over the
# inputs, since scan saves time in jit compilation.
def scan_fn(carry, X):
H, C = carry
I = jax.nn.sigmoid(jnp.matmul(X, self.W_xi) + (
jnp.matmul(H, self.W_hi)) + self.b_i)
F = jax.nn.sigmoid(jnp.matmul(X, self.W_xf) +
jnp.matmul(H, self.W_hf) + self.b_f)
O = jax.nn.sigmoid(jnp.matmul(X, self.W_xo) +
jnp.matmul(H, self.W_ho) + self.b_o)
C_tilde = jnp.tanh(jnp.matmul(X, self.W_xc) +
jnp.matmul(H, self.W_hc) + self.b_c)
C = F * C + I * C_tilde
H = O * jnp.tanh(C)
return (H, C), H # return carry, y
if H_C is None:
batch_size = inputs.shape[1]
carry = jnp.zeros((batch_size, self.num_hiddens)), \
jnp.zeros((batch_size, self.num_hiddens))
else:
carry = H_C
# scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
return outputs, carry
class LSTMScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: tf.Variable(tf.random.normal(shape) * sigma)
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
tf.Variable(tf.zeros(num_hiddens)))
self.W_xi, self.W_hi, self.b_i = triple() # Input gate
self.W_xf, self.W_hf, self.b_f = triple() # Forget gate
self.W_xo, self.W_ho, self.b_o = triple() # Output gate
self.W_xc, self.W_hc, self.b_c = triple() # Input node
实际的模型如上所述定义,由三个门和一个输入节点组成。请注意,只有隐藏状态被传递到输出层。
@d2l.add_to_class(LSTMScratch)
def forward(self, inputs, H_C=None):
if H_C is None:
# Initial state with shape: (batch_size, num_hiddens)
H = tf.zeros((inputs.shape[1], self.num_hiddens))
C = tf.zeros((inputs.shape[1], self.num_hiddens))
else:
H, C = H_C
outputs = []
for X in inputs:
I = tf.sigmoid(tf.matmul(X, self.W_xi) +
tf.matmul(H, self.W_hi) + self.b_i)
F = tf.sigmoid(tf.matmul(X, self.W_xf) +
tf.matmul(H, self.W_hf) + self.b_f)
O = tf.sigmoid(tf.matmul(X, self.W_xo) +
tf.matmul(H, self.W_ho) + self.b_o)
C_tilde = tf.tanh(tf.matmul(X, self.W_xc) +
tf.matmul(H, self.W_hc) + self.b_c)
C = F * C + I * C_tilde
H = O * tf.tanh(C)
outputs.append(H)
return outputs, (H, C)
10.1.2.2. 训练和预测¶
让我们通过实例化 9.5节 中的 RNNLMScratch
类来训练一个LSTM模型。
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
with d2l.try_gpu():
lstm = LSTMScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(lstm, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)
10.1.3. 简洁实现¶
使用高级API,我们可以直接实例化一个LSTM模型。这封装了我们上面明确说明的所有配置细节。代码速度明显更快,因为它使用了编译好的运算符而不是Python来处理我们之前详细说明的许多细节。
class LSTM(d2l.RNN):
def __init__(self, num_inputs, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = nn.LSTM(num_inputs, num_hiddens)
def forward(self, inputs, H_C=None):
return self.rnn(inputs, H_C)
lstm = LSTM(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has a the time travelly'
class LSTM(d2l.RNN):
def __init__(self, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = rnn.LSTM(num_hiddens)
def forward(self, inputs, H_C=None):
if H_C is None: H_C = self.rnn.begin_state(
inputs.shape[1], ctx=inputs.ctx)
return self.rnn(inputs, H_C)
lstm = LSTM(num_hiddens=32)
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has all the time travel'
class LSTM(d2l.RNN):
num_hiddens: int
@nn.compact
def __call__(self, inputs, H_C=None, training=False):
if H_C is None:
batch_size = inputs.shape[1]
H_C = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0),
(batch_size,),
self.num_hiddens)
LSTM = nn.scan(nn.OptimizedLSTMCell, variable_broadcast="params",
in_axes=0, out_axes=0, split_rngs={"params": False})
H_C, outputs = LSTM()(H_C, inputs)
return outputs, H_C
lstm = LSTM(num_hiddens=32)
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
model.predict('it has', 20, data.vocab, trainer.state.params)
'it has and the pered han a'
class LSTM(d2l.RNN):
def __init__(self, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = tf.keras.layers.LSTM(
num_hiddens, return_sequences=True,
return_state=True, time_major=True)
def forward(self, inputs, H_C=None):
outputs, *H_C = self.rnn(inputs, H_C)
return outputs, H_C
lstm = LSTM(num_hiddens=32)
with d2l.try_gpu():
model = d2l.RNNLM(lstm, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
model.predict('it has', 20, data.vocab)
'it has a dimension a dimen'
LSTM是具有非平凡状态控制的典型潜变量自回归模型。多年来,人们提出了许多变体,例如多层、残差连接、不同类型的正则化。然而,由于序列的长期依赖性,训练LSTM和其他序列模型(如GRU)的成本相当高。稍后我们将遇到在某些情况下可以使用的替代模型,如Transformer。
10.1.4. 小结¶
虽然LSTM于1997年发表,但它们在2000年代中期的一些预测竞赛中取得胜利后声名鹊起,并从2011年到2017年Transformer模型兴起之前,成为序列学习的主导模型。即使是Transformer,其一些关键思想也归功于LSTM引入的架构设计创新。
LSTM有三种类型的门:控制信息流的输入门、遗忘门和输出门。LSTM的隐藏层输出包括隐藏状态和记忆单元内部状态。只有隐藏状态被传递到输出层,而记忆单元内部状态完全保持在内部。LSTM可以缓解梯度消失和爆炸的问题。