10.2. Gated Recurrent Units (GRU)
As RNNs and particularly the LSTM architecture (Section 10.1) rapidly gained popularity during the 2010s, a number of researchers began to experiment with simplified architectures, hoping to retain the key idea of incorporating an internal state and multiplicative gating mechanisms, but with the aim of speeding up computation. The gated recurrent unit (GRU) (Cho et al., 2014) offered a streamlined version of the LSTM memory cell that often achieves comparable performance, with the advantage of being faster to compute (Chung et al., 2014).
import torch
from torch import nn
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
10.2.1. Reset Gate and Update Gate
Here, the LSTM's three gates are replaced by two: the reset gate and the update gate. As with LSTMs, these gates are given sigmoid activations, forcing their values to lie in the interval \((0, 1)\). Intuitively, the reset gate controls how much of the previous state we might still want to remember. Likewise, the update gate lets us control how much of the new state is just a copy of the old one. Fig. 10.2.1 illustrates the inputs for both the reset and update gates in a GRU, given the input of the current time step and the hidden state of the previous time step. The outputs of the gates are given by two fully connected layers with a sigmoid activation function.
Fig. 10.2.1 Computing the reset gate and the update gate in a GRU model.
Mathematically, for a given time step \(t\), suppose that the input is a minibatch \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\) (number of examples \(=n\); number of inputs \(=d\)) and the hidden state of the previous time step is \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\) (number of hidden units \(=h\)). Then the reset gate \(\mathbf{R}_t \in \mathbb{R}^{n \times h}\) and update gate \(\mathbf{Z}_t \in \mathbb{R}^{n \times h}\) are computed as follows:

\[\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xr}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hr}} + \mathbf{b}_\textrm{r}),\]
\[\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xz}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hz}} + \mathbf{b}_\textrm{z}),\]

where \(\mathbf{W}_{\textrm{xr}}, \mathbf{W}_{\textrm{xz}} \in \mathbb{R}^{d \times h}\) and \(\mathbf{W}_{\textrm{hr}}, \mathbf{W}_{\textrm{hz}} \in \mathbb{R}^{h \times h}\) are weight parameters and \(\mathbf{b}_\textrm{r}, \mathbf{b}_\textrm{z} \in \mathbb{R}^{1 \times h}\) are bias parameters.
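To make the shapes concrete, here is a minimal PyTorch sketch (not part of the original text) that evaluates both gate equations on randomly initialized parameters; the sizes n=2, d=5, h=4 are arbitrary illustrative choices.

import torch

# Illustrative sizes: n examples, d inputs, h hidden units
n, d, h = 2, 5, 4
X_t = torch.randn(n, d)        # minibatch input at time step t
H_prev = torch.zeros(n, h)     # hidden state of the previous time step
W_xr, W_hr, b_r = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
W_xz, W_hz, b_z = torch.randn(d, h), torch.randn(h, h), torch.zeros(1, h)
R_t = torch.sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)  # reset gate
Z_t = torch.sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)  # update gate
print(R_t.shape, Z_t.shape)    # both torch.Size([2, 4]), entries in (0, 1)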
10.2.4. Implementation from Scratch
To gain a better understanding of the GRU model, let's implement it from scratch.
10.2.4.1. Initializing Model Parameters
The first step is to initialize the model parameters. We draw the weights from a Gaussian distribution with standard deviation sigma and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units. We instantiate all weights and biases relating to the update gate, the reset gate, and the candidate hidden state.
class GRUScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
nn.Parameter(torch.zeros(num_hiddens)))
self.W_xz, self.W_hz, self.b_z = triple() # Update gate
self.W_xr, self.W_hr, self.b_r = triple() # Reset gate
self.W_xh, self.W_hh, self.b_h = triple() # Candidate hidden state
class GRUScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: np.random.randn(*shape) * sigma
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
np.zeros(num_hiddens))
self.W_xz, self.W_hz, self.b_z = triple() # Update gate
self.W_xr, self.W_hr, self.b_r = triple() # Reset gate
self.W_xh, self.W_hh, self.b_h = triple() # Candidate hidden state
class GRUScratch(d2l.Module):
num_inputs: int
num_hiddens: int
sigma: float = 0.01
def setup(self):
init_weight = lambda name, shape: self.param(name,
nn.initializers.normal(self.sigma),
shape)
triple = lambda name : (
init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens)))
self.W_xz, self.W_hz, self.b_z = triple('z') # Update gate
self.W_xr, self.W_hr, self.b_r = triple('r') # Reset gate
self.W_xh, self.W_hh, self.b_h = triple('h') # Candidate hidden state
class GRUScratch(d2l.Module):
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: tf.Variable(tf.random.normal(shape) * sigma)
triple = lambda: (init_weight(num_inputs, num_hiddens),
init_weight(num_hiddens, num_hiddens),
tf.Variable(tf.zeros(num_hiddens)))
self.W_xz, self.W_hz, self.b_z = triple() # Update gate
self.W_xr, self.W_hr, self.b_r = triple() # Reset gate
self.W_xh, self.W_hh, self.b_h = triple() # Candidate hidden state
10.2.4.2. Defining the Model
Now we are ready to define the GRU forward computation. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.
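For reference, the update equations implemented by the forward computation below are the standard GRU ones: the candidate hidden state \(\tilde{\mathbf{H}}_t\) applies the reset gate to the previous state, and the new hidden state interpolates between the old state and the candidate via the update gate:

\[\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + (\mathbf{R}_t \odot \mathbf{H}_{t-1}) \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h}),\]
\[\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t,\]

where \(\odot\) denotes elementwise multiplication.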
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
if H is None:
# Initial state with shape: (batch_size, num_hiddens)
H = torch.zeros((inputs.shape[1], self.num_hiddens),
device=inputs.device)
outputs = []
for X in inputs:
Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
torch.matmul(H, self.W_hz) + self.b_z)
R = torch.sigmoid(torch.matmul(X, self.W_xr) +
torch.matmul(H, self.W_hr) + self.b_r)
H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
torch.matmul(R * H, self.W_hh) + self.b_h)
H = Z * H + (1 - Z) * H_tilde
outputs.append(H)
return outputs, H
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
if H is None:
# Initial state with shape: (batch_size, num_hiddens)
H = np.zeros((inputs.shape[1], self.num_hiddens),
ctx=inputs.ctx)
outputs = []
for X in inputs:
Z = npx.sigmoid(np.dot(X, self.W_xz) +
np.dot(H, self.W_hz) + self.b_z)
R = npx.sigmoid(np.dot(X, self.W_xr) +
np.dot(H, self.W_hr) + self.b_r)
H_tilde = np.tanh(np.dot(X, self.W_xh) +
np.dot(R * H, self.W_hh) + self.b_h)
H = Z * H + (1 - Z) * H_tilde
outputs.append(H)
return outputs, H
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
# Use lax.scan primitive instead of looping over the
# inputs, since scan saves time in jit compilation
def scan_fn(H, X):
Z = jax.nn.sigmoid(jnp.matmul(X, self.W_xz) + jnp.matmul(H, self.W_hz) +
self.b_z)
R = jax.nn.sigmoid(jnp.matmul(X, self.W_xr) +
jnp.matmul(H, self.W_hr) + self.b_r)
H_tilde = jnp.tanh(jnp.matmul(X, self.W_xh) +
jnp.matmul(R * H, self.W_hh) + self.b_h)
H = Z * H + (1 - Z) * H_tilde
return H, H # return carry, y
if H is None:
batch_size = inputs.shape[1]
carry = jnp.zeros((batch_size, self.num_hiddens))
else:
carry = H
# scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
return outputs, carry
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
if H is None:
# Initial state with shape: (batch_size, num_hiddens)
H = tf.zeros((inputs.shape[1], self.num_hiddens))
outputs = []
for X in inputs:
Z = tf.sigmoid(tf.matmul(X, self.W_xz) +
tf.matmul(H, self.W_hz) + self.b_z)
R = tf.sigmoid(tf.matmul(X, self.W_xr) +
tf.matmul(H, self.W_hr) + self.b_r)
H_tilde = tf.tanh(tf.matmul(X, self.W_xh) +
tf.matmul(R * H, self.W_hh) + self.b_h)
H = Z * H + (1 - Z) * H_tilde
outputs.append(H)
return outputs, H
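As a quick smoke test (not part of the original text), one can run the PyTorch scratch implementation above on a dummy minibatch; the sizes below are arbitrary, and the input layout num_steps × batch_size × num_inputs follows what the forward code expects.

import torch

# Illustrative sizes only; GRUScratch refers to the PyTorch class defined above
num_steps, batch_size, num_inputs, num_hiddens = 8, 4, 28, 32
gru = GRUScratch(num_inputs=num_inputs, num_hiddens=num_hiddens)
X = torch.randn(num_steps, batch_size, num_inputs)  # dummy input sequence
outputs, H = gru(X)
print(len(outputs), outputs[0].shape, H.shape)  # 8 hidden states of shape (4, 32)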
10.2.4.3. Training
Training a language model on The Time Machine dataset works in exactly the same manner as in Section 9.5.
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
with d2l.try_gpu():
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)
10.2.5. Concise Implementation
In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration detail that we made explicit above.
class GRU(d2l.RNN):
def __init__(self, num_inputs, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = nn.GRU(num_inputs, num_hiddens)
class GRU(d2l.RNN):
def __init__(self, num_inputs, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = rnn.GRU(num_hiddens)
class GRU(d2l.RNN):
num_hiddens: int
@nn.compact
def __call__(self, inputs, H=None, training=False):
if H is None:
batch_size = inputs.shape[1]
H = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0),
(batch_size,), self.num_hiddens)
GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
in_axes=0, out_axes=0, split_rngs={"params": False})
H, outputs = GRU()(H, inputs)
return outputs, H
class GRU(d2l.RNN):
def __init__(self, num_inputs, num_hiddens):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
return_state=True)
The code is significantly faster in training as it uses compiled operators rather than Python.
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
with d2l.try_gpu():
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
After training, we print out the perplexity on the training set and the predicted sequence following the provided prefix.
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has so it and the time '
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has i have the time tra'
model.predict('it has', 20, data.vocab, trainer.state.params)
'it has is a move and a mov'
model.predict('it has', 20, data.vocab)
'it has t t t t t t t t t t'
10.2.6. Summary
Compared with LSTMs, GRUs achieve similar performance but tend to be computationally lighter. Generally, compared with simple RNNs, gated RNNs such as LSTMs and GRUs can better capture dependencies for sequences with large time step distances. GRUs contain basic RNNs as their extreme case whenever the reset gate is switched on. They can also skip subsequences by turning on the update gate.
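To make the last two points concrete: in the update \(\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t\), driving the update gate to \(\mathbf{Z}_t = 1\) yields \(\mathbf{H}_t = \mathbf{H}_{t-1}\), so the old state is carried forward unchanged and the current input is effectively skipped; with \(\mathbf{Z}_t = 0\) and the reset gate fully on (\(\mathbf{R}_t = 1\)), the update reduces to \(\mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h})\), which is exactly the vanilla RNN update.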