10.2. 门控循环单元(GRU)
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 Colab 中打开 Notebook
在 SageMaker Studio Lab 中打开 Notebook

随着RNN(特别是LSTM架构(第 10.1 节))在2010年代迅速普及,许多研究人员开始尝试简化架构,希望保留包含内部状态和乘法门控机制的关键思想,但目标是加快计算速度。门控循环单元(GRU) (Cho et al., 2014) 提供了LSTM记忆单元的精简版本,通常能实现相当的性能,但其优势是计算速度更快 (Chung et al., 2014)

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l

10.2.1. 重置门和更新门

在这里,LSTM的三个门被替换为两个:重置门(reset gate)和更新门(update gate)。与LSTM一样,这些门也使用sigmoid激活函数,将其值限制在区间 \((0, 1)\) 内。直观地说,重置门控制我们可能还想记住多少过去的状态。同样,更新门可以让我们控制新状态中有多少只是旧状态的副本。图 10.2.1 展示了在给定当前时间步的输入和前一时间步的隐藏状态下,GRU中重置门和更新门的输入。门的输出由两个带有sigmoid激活函数的全连接层给出。

../_images/gru-1.svg

图 10.2.1 计算GRU模型中的重置门和更新门。

在数学上,对于给定的时间步 \(t\),假设输入是一个小批量 \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\)(样本数 \(=n\);输入数 \(=d\)),前一时间步的隐藏状态是 \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\)(隐藏单元数 \(=h\))。那么,重置门 \(\mathbf{R}_t \in \mathbb{R}^{n \times h}\) 和更新门 \(\mathbf{Z}_t \in \mathbb{R}^{n \times h}\) 的计算如下

(10.2.1)\[\begin{split}\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xr}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hr}} + \mathbf{b}_\textrm{r}),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{\textrm{xz}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hz}} + \mathbf{b}_\textrm{z}), \end{aligned}\end{split}\]

其中 \(\mathbf{W}_{\textrm{xr}}, \mathbf{W}_{\textrm{xz}} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{\textrm{hr}}, \mathbf{W}_{\textrm{hz}} \in \mathbb{R}^{h \times h}\) 是权重参数,\(\mathbf{b}_\textrm{r}, \mathbf{b}_\textrm{z} \in \mathbb{R}^{1 \times h}\) 是偏置参数。

10.2.2. 候选隐藏状态

接下来,我们将重置门 \(\mathbf{R}_t\)(9.4.5) 中的常规更新机制相结合,从而在时间步 \(t\) 得到以下候选隐藏状态 \(\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}\)

(10.2.2)\[\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h}),\]

其中 \(\mathbf{W}_{\textrm{xh}} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{\textrm{hh}} \in \mathbb{R}^{h \times h}\) 是权重参数,\(\mathbf{b}_\textrm{h} \in \mathbb{R}^{1 \times h}\) 是偏置,符号 \(\odot\) 是哈达玛(按元素)乘积运算符。这里我们使用tanh激活函数。

结果是一个候选状态,因为我们仍然需要纳入更新门的操作。与 (9.4.5) 相比,在 (10.2.2) 中,通过 \(\mathbf{R}_t\)\(\mathbf{H}_{t-1}\) 的按元素相乘,现在可以减少先前状态的影响。每当重置门 \(\mathbf{R}_t\) 中的条目接近1时,我们就恢复一个像 (9.4.5) 中的普通RNN。对于重置门 \(\mathbf{R}_t\) 中所有接近0的条目,候选隐藏状态是以 \(\mathbf{X}_t\) 为输入的多层感知机的结果。因此,任何预先存在的隐藏状态都被重置为默认值。

图 10.2.2 展示了应用重置门后的计算流程。

../_images/gru-2.svg

图 10.2.2 计算GRU模型中的候选隐藏状态。

10.2.3. 隐藏状态

最后,我们需要整合更新门 \(\mathbf{Z}_t\) 的影响。这决定了新的隐藏状态 \(\mathbf{H}_t \in \mathbb{R}^{n \times h}\) 与旧状态 \(\mathbf{H}_{t-1}\) 的匹配程度,以及它与新的候选状态 \(\tilde{\mathbf{H}}_t\) 的相似程度。更新门 \(\mathbf{Z}_t\) 可以用于此目的,只需对 \(\mathbf{H}_{t-1}\)\(\tilde{\mathbf{H}}_t\) 进行按元素的凸组合。这导致了GRU的最终更新方程

(10.2.3)\[\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.\]

每当更新门 \(\mathbf{Z}_t\) 接近1时,我们简单地保留旧状态。在这种情况下,来自 \(\mathbf{X}_t\) 的信息被忽略,实际上是跳过了依赖链中的时间步 \(t\)。相反,当 \(\mathbf{Z}_t\) 接近0时,新的潜在状态 \(\mathbf{H}_t\) 接近于候选潜在状态 \(\tilde{\mathbf{H}}_t\)图 10.2.3 展示了更新门起作用后的计算流程。

../_images/gru-3.svg

图 10.2.3 计算GRU模型中的隐藏状态。

总而言之,GRU有两个显著的特点

  • 重置门有助于捕捉序列中的短期依赖关系。

  • 更新门有助于捕捉序列中的长期依赖关系。

10.2.4. 从零开始实现

为了更好地理解GRU模型,我们从头开始实现它。

10.2.4.1. 初始化模型参数

第一步是初始化模型参数。我们从标准差为 sigma 的高斯分布中抽取权重,并将偏置设置为0。超参数 num_hiddens 定义了隐藏单元的数量。我们实例化与更新门、重置门和候选隐藏状态相关的所有权重和偏置。

class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()

        init_weight = lambda *shape: nn.Parameter(torch.randn(*shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          nn.Parameter(torch.zeros(num_hiddens)))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()

        init_weight = lambda *shape: np.random.randn(*shape) * sigma
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          np.zeros(num_hiddens))
        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state
class GRUScratch(d2l.Module):
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        init_weight = lambda name, shape: self.param(name,
                                                     nn.initializers.normal(self.sigma),
                                                     shape)
        triple = lambda name : (
            init_weight(f'W_x{name}', (self.num_inputs, self.num_hiddens)),
            init_weight(f'W_h{name}', (self.num_hiddens, self.num_hiddens)),
            self.param(f'b_{name}', nn.initializers.zeros, (self.num_hiddens)))

        self.W_xz, self.W_hz, self.b_z = triple('z')  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple('r')  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple('h')  # Candidate hidden state
class GRUScratch(d2l.Module):
    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        self.save_hyperparameters()

        init_weight = lambda *shape: tf.Variable(tf.random.normal(shape) * sigma)
        triple = lambda: (init_weight(num_inputs, num_hiddens),
                          init_weight(num_hiddens, num_hiddens),
                          tf.Variable(tf.zeros(num_hiddens)))

        self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
        self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
        self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

10.2.4.2. 定义模型

现在我们准备定义GRU的前向计算。其结构与基本RNN单元相同,只是更新方程更为复杂。

@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = torch.zeros((inputs.shape[1], self.num_hiddens),
                      device=inputs.device)
    outputs = []
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, self.W_xz) +
                        torch.matmul(H, self.W_hz) + self.b_z)
        R = torch.sigmoid(torch.matmul(X, self.W_xr) +
                        torch.matmul(H, self.W_hr) + self.b_r)
        H_tilde = torch.tanh(torch.matmul(X, self.W_xh) +
                           torch.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = np.zeros((inputs.shape[1], self.num_hiddens),
                      ctx=inputs.ctx)
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, self.W_xz) +
                        np.dot(H, self.W_hz) + self.b_z)
        R = npx.sigmoid(np.dot(X, self.W_xr) +
                        np.dot(H, self.W_hr) + self.b_r)
        H_tilde = np.tanh(np.dot(X, self.W_xh) +
                           np.dot(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    # Use lax.scan primitive instead of looping over the
    # inputs, since scan saves time in jit compilation
    def scan_fn(H, X):
        Z = jax.nn.sigmoid(jnp.matmul(X, self.W_xz) + jnp.matmul(H, self.W_hz) +
                        self.b_z)
        R = jax.nn.sigmoid(jnp.matmul(X, self.W_xr) +
                        jnp.matmul(H, self.W_hr) + self.b_r)
        H_tilde = jnp.tanh(jnp.matmul(X, self.W_xh) +
                           jnp.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        return H, H  # return carry, y

    if H is None:
        batch_size = inputs.shape[1]
        carry = jnp.zeros((batch_size, self.num_hiddens))
    else:
        carry = H

    # scan takes the scan_fn, initial carry state, xs with leading axis to be scanned
    carry, outputs = jax.lax.scan(scan_fn, carry, inputs)
    return outputs, carry
@d2l.add_to_class(GRUScratch)
def forward(self, inputs, H=None):
    if H is None:
        # Initial state with shape: (batch_size, num_hiddens)
        H = tf.zeros((inputs.shape[1], self.num_hiddens))
    outputs = []
    for X in inputs:
        Z = tf.sigmoid(tf.matmul(X, self.W_xz) +
                        tf.matmul(H, self.W_hz) + self.b_z)
        R = tf.sigmoid(tf.matmul(X, self.W_xr) +
                        tf.matmul(H, self.W_hr) + self.b_r)
        H_tilde = tf.tanh(tf.matmul(X, self.W_xh) +
                           tf.matmul(R * H, self.W_hh) + self.b_h)
        H = Z * H + (1 - Z) * H_tilde
        outputs.append(H)
    return outputs, H

10.2.4.3. 训练

在《时间机器》数据集上训练语言模型的方式与 第 9.5 节 中完全相同。

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
../_images/output_gru_b77a34_48_0.svg
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
../_images/output_gru_b77a34_51_0.svg
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
../_images/output_gru_b77a34_54_0.svg
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
with d2l.try_gpu():
    gru = GRUScratch(num_inputs=len(data.vocab), num_hiddens=32)
    model = d2l.RNNLMScratch(gru, vocab_size=len(data.vocab), lr=4)
trainer = d2l.Trainer(max_epochs=50, gradient_clip_val=1)
trainer.fit(model, data)
../_images/output_gru_b77a34_57_0.svg

10.2.5. 简洁实现

在高级API中,我们可以直接实例化一个GRU模型。这封装了我们上面明确说明的所有配置细节。

class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = nn.GRU(num_inputs, num_hiddens)
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = rnn.GRU(num_hiddens)
class GRU(d2l.RNN):
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None, training=False):
        if H is None:
            batch_size = inputs.shape[1]
            H = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0),
                                            (batch_size,), self.num_hiddens)

        GRU = nn.scan(nn.GRUCell, variable_broadcast="params",
                      in_axes=0, out_axes=0, split_rngs={"params": False})

        H, outputs = GRU()(H, inputs)
        return outputs, H
class GRU(d2l.RNN):
    def __init__(self, num_inputs, num_hiddens):
        d2l.Module.__init__(self)
        self.save_hyperparameters()
        self.rnn = tf.keras.layers.GRU(num_hiddens, return_sequences=True,
                                       return_state=True)

由于使用了编译好的算子而不是Python,代码在训练时速度要快得多。

gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
../_images/output_gru_b77a34_78_0.svg
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
../_images/output_gru_b77a34_81_0.svg
gru = GRU(num_hiddens=32)
model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
../_images/output_gru_b77a34_84_0.svg
gru = GRU(num_inputs=len(data.vocab), num_hiddens=32)
with d2l.try_gpu():
    model = d2l.RNNLM(gru, vocab_size=len(data.vocab), lr=4)
trainer.fit(model, data)
../_images/output_gru_b77a34_87_0.svg

训练后,我们打印出训练集上的困惑度以及根据提供的前缀预测的序列。

model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has so it and the time '
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has i have the time tra'
model.predict('it has', 20, data.vocab, trainer.state.params)
'it has is a move and a mov'
model.predict('it has', 20, data.vocab)
'it has t t t t t t t t t t'

10.2.6. 总结

与LSTM相比,GRU实现了相似的性能,但计算上更轻量。通常,与简单RNN相比,像LSTM和GRU这样的门控RNN能更好地捕捉时间步距离较大的序列的依赖关系。当重置门被激活时,GRU包含基本RNN作为其极端情况。它们还可以通过激活更新门来跳过子序列。

10.2.7. 练习

  1. 假设我们只想使用时间步 \(t'\) 的输入来预测时间步 \(t > t'\) 的输出。对于每个时间步,重置门和更新门的最佳值是什么?

  2. 调整超参数并分析它们对运行时间、困惑度和输出序列的影响。

  3. 比较 rnn.RNNrnn.GRU 实现的运行时间、困惑度和输出字符串。

  4. 如果你只实现GRU的一部分,例如,只有一个重置门或只有一个更新门,会发生什么?